def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

       / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
      /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """
    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a[node_end_b] = (node_start_a, node_end_a)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a:
            node_start_a, node_end_a = node_b_to_matched_subgraph_a[node_b]
            # debugging prints, disabled by default
            if False:
                print('b')
                print_node(node_b)
                print('a')
                print_node(node_start_a)
                print_node(node_end_a)

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # prev_node_c -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # prev_node_c -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                env_c[dtype_cast_node.name],
                node_start_a, node_end_a, gm_a, gm_b,
                node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy(args/kwargs not shown)
            #      /
            # prev_node_c -> node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy
            #      /
            # prev_node_c -> node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                name_a, node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy --> logger_a
            #      /
            # prev_node_c -> node_c --> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c

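# A minimal usage sketch for create_a_shadows_b (the matcher helper and logger
# class names below are assumptions; adapt them to the surrounding codebase):
#
#   gm_fp32 = torch.fx.symbolic_trace(model_fp32)
#   gm_int8 = torch.fx.symbolic_trace(model_int8)
#   # matched_subgraph_pairs maps a match name to
#   # ((start_a, end_a), (start_b, end_b)) node tuples, e.g. produced by a
#   # graph matcher run over the two models
#   gm_shadowed = create_a_shadows_b(
#       'fp32', gm_fp32, 'int8', gm_int8, matched_subgraph_pairs, OutputLogger)
#   # calling gm_shadowed(x) now records activations of both models through
#   # the inserted logger modules
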
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, and adds loggers to the
    inputs and outputs of the nodes specified in
    node_to_instrument_inputs_to_ref_node_name and
    node_to_instrument_outputs_to_ref_node_name.  Returns a GraphModule with
    the new graph.
    """
    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = node.args[node_arg_idx]
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx)
                    else:
                        # arg is a non-Node constant, nothing to log
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm

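# Example of the add/mul special case handled above (a sketch, assuming
# get_arg_indices_of_inputs_to_log reports which positional args are Nodes):
#
#   y = x + 1    # one tensor arg at index 0 -> log args[0] only
#   y = 1 + x    # one tensor arg at index 1 -> log args[1] only
#   y = x1 + x2  # two tensor args          -> log args[0] and args[1]
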
def _prepare(model: GraphModule, qconfig_dict: Any,
             node_name_to_scope: Dict[str, Tuple[str, type]],
             prepare_custom_config_dict: Optional[Dict[str, Any]],
             is_standalone_module: bool) -> ObservedGraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Args:
        node_name_to_scope: mapping from node name to the scope of the module
            which contains the node. The scope is a tuple of the fully
            qualified path of the module and the type of the module.

    Returns:
        model(GraphModule): prepared standalone module
        attributes:
            _standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph inputs that are expected to be quantized,
                same as the input_quantized_idxs configuration provided
                for the standalone module
            _standalone_module_output_quantized_idxs(List[Int]): a list of
                indexes for the graph outputs that are quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    # mapping from a tuple of nodes in reverse order to uninitialized
    #   QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.quantization.fx.quantize.Add'>),
    # }
    patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    convert_dict_to_ordered_dict(qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)

    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        qat_swap_modules(model, additional_qat_module_mapping)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    modules = dict(model.named_modules())

    # fill qconfig_map, a map from node name to qconfig, used in find_matches
    qconfig_map = generate_qconfig_map(
        model, modules, model.graph, qconfig_dict, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_name_configs = prepare_custom_config_dict.get(
        "standalone_module_name", [])
    standalone_module_class_configs = prepare_custom_config_dict.get(
        "standalone_module_class", [])

    standalone_module_names = [
        config[0] for config in standalone_module_name_configs]
    standalone_module_classes = [
        config[0] for config in standalone_module_class_configs]
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns, qconfig_map, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    run_prepare_fx_on_standalone_modules(
        model, modules, matches, prepare_custom_config_dict)

    result_node = insert_observers_for_model(
        model, modules, matches, qconfig_map,
        model.graph, prepare_custom_config_dict,
        input_quantized_idxs, output_quantized_idxs)

    save_state(model, qconfig_map, node_name_to_scope, patterns,
               prepare_custom_config_dict)
    preserved_attributes = set(
        prepare_custom_config_dict.get("preserved_attributes", []))
    model = ObservedGraphModule(model, model.graph, preserved_attributes)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently" \
            "(not tuple, dict etc.)"
        # these inputs are observed in parent
        # converting List[int] to Tensor since module attribute is
        # Union[Tensor, Module]
        model._standalone_module_input_quantized_idxs = \
            torch.tensor(input_quantized_idxs)
        model._standalone_module_output_quantized_idxs = \
            torch.tensor(output_quantized_idxs)
    return model

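# Example of the idx attributes set above (a sketch): for a standalone module
# whose first input and first output are already quantized by the parent,
# the prepare_custom_config for that submodule would contain
#
#   {"input_quantized_idxs": [0], "output_quantized_idxs": [0]}
#
# and after _prepare the observed module carries
# _standalone_module_input_quantized_idxs == tensor([0]) and
# _standalone_module_output_quantized_idxs == tensor([0]).
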
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant,
             is_standalone_module):
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    input of the module is observed in parent module, output of the module
    is observed in the standalone module.
    Returns:
        model(GraphModule): prepared standalone module with the following attributes:
            _standalone_module_observed_input_idxs(List[Int]): a list of
                indexes for the graph inputs that need to be observed in
                the parent module
            _output_is_observed(Bool): a boolean variable that indicates
                whether the output of the custom module is observed or not
    """
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    if self.is_dynamic_quant:
        self.patterns = get_dynamic_quant_patterns()
    else:
        self.patterns = get_quant_patterns()

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        self._qat_swap_modules(model)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = qconfig_dict.get('standalone_module_name', None)
    matches = self._find_matches(
        model.graph, self.modules, self.patterns, standalone_module_names)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize a DefaultQuant object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    env = {}
    observed_graph = Graph()
    observed_node_names_set = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    # indexes for the inputs that need to be observed
    standalone_module_observed_input_idxs = []
    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    get_new_observer_name = get_new_attr_name_with_prefix(
        'activation_post_process_')

    for node in model.graph.nodes:
        if node.name in observed_node_names_set:
            continue

        prefix = node.name + '_activation_post_process_'
        root_node, _, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is None:
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            if qconfig is None:
                continue

            def insert_observer(node, observer, device):
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                setattr(model, observer_name, observer)
                self.activation_post_process_map[node.name] = observer
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)
                if device:
                    getattr(model, observer_name).to(device)

            if isinstance(obj, CustomModuleQuantizeHandler):
                custom_module = self.modules[node.target]
                observed_custom_module_class = \
                    get_observed_custom_module_class(type(custom_module))
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                mark_observed_custom_module(observed_custom_module,
                                            type(custom_module))
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_custom_module)

            # index for input of custom module that needs to be observed in
            # parent
            standalone_module_input_idxs = None
            if isinstance(obj, StandaloneModuleQuantizeHandler):
                # observe standalone module
                standalone_module = self.modules[node.target]
                traced_standalone_module = symbolic_trace(standalone_module)
                if self.is_dynamic_quant:
                    prepare = torch.quantization.quantize_fx._prepare_dynamic_standalone_module_fx
                else:
                    prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                observed_standalone_module = prepare(
                    traced_standalone_module, {'': qconfig})
                observed_standalone_module.qconfig = qconfig
                standalone_module_input_idxs = \
                    observed_standalone_module._standalone_module_observed_input_idxs
                observed_standalone_module = mark_observed_standalone_module(
                    observed_standalone_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_standalone_module)
                self.modules[node.target] = observed_standalone_module

            # don't need to insert observer for output in dynamic quantization
            if self.is_dynamic_quant:
                continue

            # inserting observers for output of observed module, or mark the
            # output as observed
            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                def is_observed(input_arg):
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed_node_names_set
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))

                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                if node.args[0].name in observed_node_names_set:
                    observed_node_names_set.add(node.name)
            elif isinstance(obj, StandaloneModuleQuantizeHandler):
                assert node.op == 'call_module'
                output_is_observed = \
                    self.modules[node.target]._output_is_observed
                if output_is_observed:
                    observed_node_names_set.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                new_observer = qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                insert_observer(node, new_observer, device)

            # insert observer for input of standalone module
            if standalone_module_input_idxs is not None:
                for idx in standalone_module_input_idxs:
                    if node.args[idx].name not in observed_node_names_set:
                        new_observer = qconfig.activation()
                        device = assert_and_get_unique_device(model)
                        insert_observer(node.args[idx], new_observer, device)
        else:
            env[node.name] = observed_graph.node_copy(node, load_arg)

        if node.name not in observed_node_names_set and node.name in quants:
            if is_standalone_module and node.name in graph_inputs:
                # we'll insert observer for input of standalone module
                # in parent graph
                standalone_module_observed_input_idxs.append(
                    graph_inputs.index(node.name))
                continue
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                # TODO: use insert_observer
                new_observer = \
                    qconfig.weight() if is_weight else qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                if device:
                    new_observer.to(device)
                self.activation_post_process_map[node.name] = new_observer
                setattr(model, observer_name,
                        self.activation_post_process_map[node.name])
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)

    observed_graph.output(load_arg(model.graph.result))

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    if is_standalone_module:
        assert isinstance(model.graph.result, Node), \
            'standalone module returning dict is not yet supported'
        # indicator for whether output is observed or not.
        # This is used to correctly quantize standalone modules.
        output_is_observed = \
            model.graph.result.name in observed_node_names_set
        model._standalone_module_observed_input_idxs = \
            standalone_module_observed_input_idxs
        model._output_is_observed = output_is_observed
    return model

def scale_weight_functional(
    op_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scales the weight value for functional layers
    """
    # From the given op_node, the path looks like:
    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
    # So we want to trace back from the op_node to get the equalization observer
    # node, then the quantization observer node, and then finally the weight
    # node which contains the weight values.

    # Get the equalization observer node
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    # Get the quantization observer node
    weight_quant_obs_node = weight_eq_obs_node.args[0]
    if weight_quant_obs_node is None:
        return
    assert(isinstance(weight_quant_obs_node, Node) and
           isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))

    # Get the get_attr(weight) node
    weight_node = weight_quant_obs_node.args[0]
    if weight_node is None:
        return
    assert(isinstance(weight_node, Node) and weight_node.op == 'get_attr')

    weight_parent_name, weight_name = _parent_name(weight_node.target)
    weight = getattr(modules[weight_parent_name], weight_name)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply
    # by its scale as well
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))

    if next_equalization_scale is None:
        setattr(modules[weight_parent_name], weight_name, scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    new_shape = [1] * weight.ndim
    new_shape[0] = weight.size(0)
    scaled_weight = torch.mul(
        scaled_weight, next_equalization_scale.view(new_shape))

    setattr(modules[weight_parent_name], weight_name, scaled_weight)
    assert(torch.allclose(model.get_buffer(str(weight_node.target)),
                          scaled_weight))

    # Multiply the bias element wise by the next equalization scale
    bias_node = None
    for node, _ in op_node.users.items():
        # Find the node containing the bias values
        if node.op == 'get_attr' and 'bias' in node.name:
            bias_node = node
            break
    if bias_node is None:
        return

    bias_parent_name, bias_name = _parent_name(bias_node.target)
    bias = getattr(modules[bias_parent_name], bias_name)

    scaled_bias = torch.mul(bias, next_equalization_scale)
    setattr(modules[bias_parent_name], bias_name, scaled_bias)

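# Illustration of the row-wise rescaling above (a standalone sketch, not part
# of the pass): for a 2D weight of shape (out_features, in_features),
# `new_shape` becomes (out_features, 1), so `next_equalization_scale`
# broadcasts over rows while `equalization_scale` divides each input column.
#
#   weight = torch.ones(3, 4)                             # (out, in)
#   eq_scale = torch.tensor([1.0, 2.0, 4.0, 8.0])         # one per input
#   next_scale = torch.tensor([10.0, 20.0, 30.0])         # one per output
#   w = torch.mul(weight, torch.reciprocal(eq_scale))     # divide columns
#   w = torch.mul(w, next_scale.view(3, 1))               # multiply rows
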
def _prepare_fx(
    model: torch.nn.Module,
    qconfig_dict: Any,
    prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
    equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
    is_standalone_module: bool = False,
    is_qat: bool = False,
) -> ObservedGraphModule:
    r""" Internal helper function for prepare_fx
    Args:
        `model`, `qconfig_dict`, `prepare_custom_config_dict`,
        `equalization_qconfig_dict`: see docs for
            :func:`~torch.ao.quantization.prepare_fx`
        `is_standalone_module`: a boolean flag indicating whether we are
            quantizing a standalone module or not. A standalone module
            is a submodule of the parent module that is not inlined in the
            forward graph of the parent module; the way we quantize a
            standalone module is described in:
            :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if equalization_qconfig_dict is None:
        equalization_qconfig_dict = {}

    check_is_valid_qconfig_dict(qconfig_dict)
    check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict)
    check_is_valid_qconfig_dict(equalization_qconfig_dict)

    skipped_module_names = prepare_custom_config_dict.get(
        "non_traceable_module_name", [])
    skipped_module_classes = prepare_custom_config_dict.get(
        "non_traceable_module_class", [])

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    # symbolically trace the model
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level
        # module
        standalone_module_name_configs = prepare_custom_config_dict.get(
            "standalone_module_name", [])
        skipped_module_names += [
            config[0] for config in standalone_module_name_configs]

        standalone_module_class_configs = prepare_custom_config_dict.get(
            "standalone_module_class", [])
        skipped_module_classes += [
            config[0] for config in standalone_module_class_configs]
        float_custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        skipped_module_classes += float_custom_module_classes

    preserved_attributes = prepare_custom_config_dict.get(
        "preserved_attributes", [])
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    for attr_name in preserved_attributes:
        setattr(graph_module, attr_name, getattr(model, attr_name))
    graph_module = _fuse_fx(
        graph_module, prepare_custom_config_dict, backend_config_dict)
    prepared = prepare(
        graph_module,
        qconfig_dict,
        tracer.node_name_to_scope,
        prepare_custom_config_dict=prepare_custom_config_dict,
        equalization_qconfig_dict=equalization_qconfig_dict,
        backend_config_dict=backend_config_dict,
        is_standalone_module=is_standalone_module,
        is_qat=is_qat,
    )

    for attr_name in preserved_attributes:
        setattr(prepared, attr_name, getattr(model, attr_name))
    return prepared

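# Typical entry point for the helper above (a sketch; the qconfig choice is
# just an example):
#
#   from torch.quantization import get_default_qconfig
#   from torch.quantization.quantize_fx import prepare_fx
#
#   qconfig_dict = {"": get_default_qconfig("fbgemm")}
#   prepared = prepare_fx(model_fp32.eval(), qconfig_dict)  # calls _prepare_fx internally
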
def _convert_do_not_use(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """
    We will convert an observed model (a module with observer calls) to a
    reference quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls
       to quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to
       reference quantized modules, this requires us to know whether the dtype
       configured for the weight is supported in the backend, this is done in
       the prepare step and the result is stored in observed_node_names, we
       can decide whether we need to swap the module based on this set

    standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module; whether input/output is quantized
    is specified by prepare_custom_config_dict, with input_quantized_idxs,
    output_quantized_idxs, please see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the
        # scaled inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node,
            modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [torch.quint8, torch.qint8, torch.float16]:
            node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as
                        # buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like
                        # axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(
                    node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method(
                    "dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # Note: we don't need to do anything for this, it affects the
                # prepare step in terms of whether to insert observer for
                # input or not
                continue
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result is kept quantized if the user specified the
                # output_quantized_idxs override.
                # Remove the dequantize operator at the end
                maybe_dequantize_node = node.args[0]
                if isinstance(maybe_dequantize_node, Node) and \
                        maybe_dequantize_node.op == "call_method" and \
                        maybe_dequantize_node.target == "dequantize":
                    quantized_node = maybe_dequantize_node.args[0]
                    maybe_dequantize_node.replace_all_uses_with(quantized_node)
                    model.graph.erase_node(maybe_dequantize_node)
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(
                    model.graph, node, modules)
            elif type(modules[node.target]) in set(
                    WEIGHTED_MODULE_CLASSES).union(
                        QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES):
                # TODO: refactor this part to a function
                original_module = modules[node.target]
                qconfig = original_module.qconfig

                is_observed = node.name in observed_node_names
                is_weight_quantized = weight_is_statically_quantized(qconfig)
                # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
                if qconfig is None or not is_observed or not is_weight_quantized:
                    continue

                float_module = original_module
                fused_module = None
                if isinstance(original_module, QAT_MODULE_CLASSES):
                    # case 1. converting qat module to
                    # a float module, we need to attach
                    # weight fake_quant to the module,
                    # weight fake_quant is assumed to be run during
                    # QAT so we don't need to run it again here
                    float_module = original_module.to_float()  # type: ignore[operator]
                    # change qat conv to conv
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, float_module)
                    if isinstance(float_module, torch.nn.intrinsic._FusedModule):
                        fused_module = float_module
                        float_module = fused_module[0]
                    weight_post_process = original_module.weight_fake_quant
                else:
                    # case 2. converting a float module/fused float module
                    # to float module, we need to attach
                    # weight observer to the conv module and run it
                    # with conv weight
                    if isinstance(original_module, torch.nn.intrinsic._FusedModule):
                        fused_module = original_module
                        float_module = fused_module[0]  # type: ignore[index]
                    assert qconfig is not None
                    weight_post_process = qconfig.weight()
                    # run weight observer
                    weight_post_process(float_module.weight)  # type: ignore[operator]
                weight_qparams = get_qparam_dict(weight_post_process)
                ref_qmodule_cls = get_static_quant_module_class(
                    type(float_module), is_reference=True)
                ref_qmodule = ref_qmodule_cls.from_float(
                    float_module, weight_qparams)
                if fused_module is not None:
                    fused_module[0] = ref_qmodule
                else:
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, ref_qmodule)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model

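# For reference, the quantize/dequantize pair that replaces an observer in
# replace_observer_with_quantize_dequantize_node corresponds to this eager
# pattern (a standalone sketch with made-up qparams, not part of the pass):
#
#   x = torch.randn(4)
#   q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
#   dq = q.dequantize()  # fp32 -> quantize -> dequantize -> fp32 (reference pattern)
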
def _convert(self, model: GraphModule, debug: bool = False,
             convert_custom_config_dict: Dict[str, Any] = None,
             is_standalone_module: bool = False) -> GraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module which accepts float input
    and produces float output.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    self.restore_state(model)
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    assert self.patterns is not None
    matches = self._find_matches(
        model.graph, self.modules, self.patterns,
        custom_module_classes=custom_module_classes)

    quants = self._find_quants(model.graph, matches)

    self.quantized_graph = Graph()
    env: Dict[Any, Any] = {}
    quant_env: Dict[Any, Any] = {}

    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n):
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find ' + \
                'node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + \
                str(env) + ' quant_env:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in float environment:' + str(env)
            assert n.name in quants, \
                'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is None, then we'll load the node as long as it
            exists

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def is_quantized(node):
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, \
                'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but
            # quant_env will take precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            quantized = map(is_quantized, node)
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")

    def is_output_quantized(node) -> bool:
        """ Check if output node is quantized or not """
        assert self.modules is not None
        # by default the output is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of CopyNode
        if type(obj) in [CopyNode, FixedQParamsOpQuantizeHandler]:
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'CopyNode of type ' + node.op + ' is not handled'
            quantized = is_quantized(node.args[0])

        if not activation_is_statically_quantized(qconfig) or \
                not input_output_observed(obj):
            quantized = False

        return quantized

    def insert_quantize_node(node):
        """ Given an activation_post_process module call node, insert a
        quantize node"""
        assert self.modules is not None
        observer_module = self.modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float16:
            # activations are not quantized for
            # fp16 dynamic quantization
            # copy the activation_post_process node here
            # since we may need it when we insert prepack
            # op for weight of linear, this will be removed
            # later in a separate pass
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)
        elif prev_node.name in quant_env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            quant_env[node.name] = quant_env[prev_node.name]
        else:
            # replace activation post process with quantization ops
            root_module = self.modules[""]
            quant_env[node.name] = quantize_node(
                root_module, self.quantized_graph,
                load_non_quantized(node.args[0]), observer_module)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = convert_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = convert_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == 'output':
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result is kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            self.quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            if qconfig is None:
                result = self.quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                is_standalone_module_node = is_observed_standalone_module_node(
                    node, self.modules)
                result = obj.convert(
                    self, node, load_arg, debug=debug,
                    convert_custom_config_dict=convert_custom_config_dict)
                if is_standalone_module_node:
                    quantized = False
                else:
                    quantized = is_output_quantized(node)

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(self.modules[node.target]):
            insert_quantize_node(node)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                quant_env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            env[node.name] = \
                self.quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    env = {}

    def load_arg(a):  # type: ignore
        return map_arg(a, lambda node: env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg))
            continue
        if node.op == 'call_module' and \
                is_activation_post_process(self.modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]
        else:
            env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg)

    # removes qconfig and activation_post_process modules
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model

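# A note on the two-environment design above: `env` holds fp32 nodes and
# `quant_env` holds quantized nodes; load_non_quantized bridges the two by
# emitting a dequantize call through the FX Proxy wrapper. A minimal sketch
# of that bridging idiom (standalone illustration, not part of the pass):
#
#   quant_node = quant_env[name]                       # produces a quantized tensor
#   fp32_node = Proxy(quant_node).dequantize().node    # records call_method('dequantize')
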
    tanh_1 = torch.tanh(cat_1);  cat_1 = None
    neg_1 = torch.neg(tanh_1);  tanh_1 = None
    return neg_1
"""
# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder("x")
raw2 = graph.placeholder("y")

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)

def _prepare(self, model: GraphModule, qconfig_dict: Any,
             prepare_custom_config_dict: Optional[Dict[str, Any]],
             is_standalone_module: bool) -> GraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    both input and output are observed in the prepared standalone module
    Returns:
        model(GraphModule): prepared standalone module
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    self.patterns = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        self._qat_swap_modules(model, additional_qat_module_mapping)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    standalone_module_classes = prepare_custom_config_dict.get(
        "standalone_module_class", None)
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    assert self.patterns is not None
    matches = self._find_matches(
        model.graph, self.modules, self.patterns, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize a DefaultQuantizeHandler object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    env: Dict[Any, Any] = {}
    observed_graph = Graph()
    observed_node_names_set: Set[str] = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    # indexes for the inputs that need to be observed
    standalone_module_observed_input_idxs: List[int] = []
    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    get_new_observer_name = get_new_attr_name_with_prefix(
        'activation_post_process_')
    model_device = assert_and_get_unique_device(model)

    result_node: Optional[Node] = None
    for node in model.graph.nodes:
        if node.op == 'output':
            observed_graph.output(load_arg(node.args[0]))
            result_node = node
            continue
        if node.name in observed_node_names_set:
            continue

        root_node, matched_nodes, pattern, obj, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        if root_node is None:
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            # index for input of custom module that needs to be observed in
            # parent
            if qconfig is not None:
                assert obj is not None
                insert_observer_for_special_module(
                    obj, self.modules, prepare_custom_config_dict, qconfig,
                    node)
                insert_observer_for_output_of_the_node(
                    node, obj, qconfig, self.modules, model, pattern,
                    model_device, self.activation_post_process_map, env,
                    observed_graph, load_arg, observed_node_names_set,
                    matched_nodes)
        else:
            env[node.name] = observed_graph.node_copy(node, load_arg)

        insert_observer_for_input_arg_of_observed_node(
            node, observed_node_names_set, quants,
            model_device, model, self.activation_post_process_map, env,
            observed_graph, load_arg)

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    model = mark_observed_module(model)
    return model

def save_state(self, observed: GraphModule) -> None:
    observed._activation_post_process_map = \
        self.activation_post_process_map  # type: ignore
    observed._patterns = self.patterns  # type: ignore
    observed._qconfig_map = self.qconfig_map  # type: ignore

def call(self, graph_module: GraphModule) -> PassResult:
    """
    Return a new copy of torch.fx.GraphModule with CSE applied to the input
    graph

    Example usage:

    from torch.fx.experimental.proxy_tensor import make_fx
    def f(a):
        b = a * a
        c = a * a
        return b + c

    p = CSEPass()
    traced_graph = make_fx(f)(torch.tensor(1))
    print(traced_graph)
    result = p(traced_graph)
    print(result.graph_module)
    """
    def get_aten_target(node):
        if hasattr(node.target, 'overloadpacket'):
            return node.target.overloadpacket
        return node.target

    modified = False
    new_graph = Graph()
    env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
    hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
    token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token
    for n in graph_module.graph.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new
        # graph without change; do not CSE away random operations
        if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' \
                or get_aten_target(n) in self.banned_ops:
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:
            # n.op == 'call_function', should never see n.op == 'call_module'
            # or 'call_method'.
            # Substitute args and kwargs members with their mapping in env if
            # one exists; specs can be used to reconstruct nested
            # lists/dictionaries
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, Node) and v in env:
                        arg_list[i] = env[v]
                return tuple(arg_list), spec
            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # each token corresponds to a unique node
            # nodes with the same token can be substituted
            token = {"target": n.target, "args": args, "args_spec": args_spec,
                     "kwargs": kwargs, "kwargs_spec": kwargs_spec}

            # hash substituted args to a number, do not hash specs because
            # specs are not hashable
            hash_arg = hash((args, kwargs))
            hash_val = (n.target, hash_arg)

            # check if a node has a substitute and can be eliminated
            hash_val_in_hash_env = hash_val in hash_env
            if hash_val_in_hash_env and token_map[hash_val] == token:
                modified = True  # substitution happens and the graph is modified
                env[n] = hash_env[hash_val]
                continue

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    csed_gm = GraphModule(graph_module, new_graph)
    return PassResult(csed_gm, modified)

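# Concrete illustration of the hashing scheme above (a sketch): for
#
#   b = a * a
#   c = a * a
#
# both call_function nodes share the same target and the same substituted
# args, so they produce equal `hash_val` keys and equal `token` dicts; the
# second node is dropped and `env` points both old nodes at the first copy.
# The full token comparison guards against hash collisions between distinct
# calls that happen to hash alike.
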
def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
             is_standalone_module):
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    input of the module is observed in parent module, output of the module
    is observed in the standalone module.
    Returns:
        model(GraphModule): prepared standalone module with the following attributes:
            _standalone_module_observed_input_idxs(List[Int]): a list of
                indexes for the graph inputs that need to be observed in
                the parent module
            _output_is_observed(Bool): a boolean variable that indicates
                whether the output of the custom module is observed or not
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    additional_quant_patterns = prepare_custom_config_dict.get(
        "additional_quant_pattern", {})
    self.patterns = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        self._qat_swap_modules(model, additional_qat_module_mapping)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    standalone_module_classes = prepare_custom_config_dict.get(
        "standalone_module_class", None)
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    matches = self._find_matches(
        model.graph, self.modules, self.patterns, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize a DefaultQuantizeHandler object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    env = {}
    observed_graph = Graph()
    observed_node_names_set = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    # indexes for the inputs that need to be observed
    standalone_module_observed_input_idxs = []
    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    get_new_observer_name = get_new_attr_name_with_prefix(
        'activation_post_process_')

    model_device = assert_and_get_unique_device(model)

    def insert_observer(node, observer):
        """Insert observer for node by modifying the observed_graph and
           attach observer module to the model
           Args:
               node: Node
               observer: observer/fake_quantize module instance
        """
        # respect device affinity when adding observers
        if model_device:
            observer.to(model_device)
        # add observer module as attribute
        prefix = node.name + '_activation_post_process_'
        get_new_observer_name = get_new_attr_name_with_prefix(prefix)
        observer_name = get_new_observer_name(model)
        setattr(model, observer_name, observer)
        # put observer instance in the activation_post_process map
        self.activation_post_process_map[node.name] = observer
        # insert observer call
        env[node.name] = observed_graph.create_node(
            'call_module', observer_name, (load_arg(node),), {})
        observed_node_names_set.add(node.name)

    result_node: Optional[Node] = None
    for node in model.graph.nodes:
        if node.op == 'output':
            observed_graph.output(load_arg(node.args[0]))
            result_node = node
            continue
        if node.name in observed_node_names_set:
            continue

        root_node, matched_nodes, pattern, obj, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        if root_node is None:
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            # index for input of custom module that needs to be observed in
            # parent
            standalone_module_input_idxs = None
            if qconfig is not None:
                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    custom_module_class_mapping = prepare_custom_config_dict.get(
                        "float_to_observed_custom_module_class", {})
                    observed_custom_module_class = \
                        get_swapped_custom_module_class(
                            custom_module, custom_module_class_mapping,
                            qconfig)
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_custom_module)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    prepare = \
                        torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(
                        standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = \
                        observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = \
                        mark_observed_standalone_module(
                            observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module

                # don't need to insert observer for output if activation does
                # not need to be statically quantized
                if activation_is_statically_quantized(qconfig):
                    if isinstance(obj, FixedQParamsOpQuantizeHandler) and \
                            model.training:
                        # we only insert fake quantize module in qat
                        activation_post_process_ctr = \
                            get_default_output_activation_post_process_map().get(
                                pattern, None)
                        assert activation_post_process_ctr is not None, \
                            "activation_post_process constructor not provided for " + \
                            "pattern:" + str(pattern)
                        insert_observer(node, activation_post_process_ctr())
                    elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and
                            not model.training) or isinstance(obj, CopyNode):
                        # inserting observers for output of observed module,
                        # or mark the output as observed
                        assert node.op in [
                            'call_module',
                            'call_function',
                            'call_method'], \
                            'CopyNode of type ' + node.op + ' is not handled'

                        def is_observed(input_arg):
                            if isinstance(input_arg, Node):
                                return input_arg.name in observed_node_names_set
                            elif isinstance(input_arg, list):
                                return all(map(is_observed, input_arg))

                        # propagate observed property from input
                        if is_observed(node.args[0]):
                            observed_node_names_set.add(node.name)
                    elif (isinstance(obj, Add) or isinstance(obj, Mul)) and \
                            obj.num_node_args == 1:
                        input_node = matched_nodes[-1]  # first node in the sequence

                        def input_is_observed(arg):
                            return isinstance(arg, Node) and \
                                arg.name in observed_node_names_set
                        # This is checking if one of the arguments of add/mul
                        # is an observed node.
                        # If both of the inputs are numbers,
                        # we will not consider the output to be observed
                        if input_is_observed(input_node.args[0]) or \
                                input_is_observed(input_node.args[1]):
                            observed_node_names_set.add(node.name)
                    elif isinstance(obj, StandaloneModuleQuantizeHandler):
                        assert node.op == 'call_module'
                        output_is_observed = \
                            self.modules[node.target]._output_is_observed
                        if output_is_observed:
                            observed_node_names_set.add(node.name)
                    elif obj.all_node_args:
                        # observer for outputs
                        new_observer = qconfig.activation()
                        insert_observer(node, new_observer)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            insert_observer(node.args[idx], new_observer)
        else:
            env[node.name] = observed_graph.node_copy(node, load_arg)

        # insert observer for output of the node
        if node.name not in observed_node_names_set and node.name in quants:
            if is_standalone_module and node.name in graph_inputs:
                # we'll insert observer for input of standalone module
                # in parent graph
                standalone_module_observed_input_idxs.append(
                    graph_inputs.index(node.name))
                continue
            _, activation_post_process_ctr = quants[node.name]
            if activation_post_process_ctr is not None:
                insert_observer(node, activation_post_process_ctr())

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    model = mark_observed_module(model)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            'standalone module returning dict is not yet supported'
        # indicator for whether output is observed or not.
        # This is used to correctly quantize standalone modules.
        output_is_observed = \
            result_node.args[0].name in observed_node_names_set
        model._standalone_module_observed_input_idxs = \
            standalone_module_observed_input_idxs
        model._output_is_observed = output_is_observed
    return model

def convert(model: GraphModule, is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module; whether input/output is quantized
    is specified by prepare_custom_config_dict, with input_quantized_idxs,
    output_quantized_idxs, please see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    quantized_graph = Graph()
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        quantized_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in environment:' + str(env)
        quantized_node, dtype = env[n.name]
        assert dtype in [torch.quint8, torch.qint8, torch.float16], \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        return env[n.name][0]

    def load_arg(
            quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as
            load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
                    len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to
                # check 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            quantized = map(node_arg_is_quantized, node_arg)
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(
            node: Node, obj: QuantizeHandler, qconfig: QConfigAny,
            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be
        # quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead
            # of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
                not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given an activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name] = (
                    quantize_node(load_non_quantized(prev_node),
                                  observer_module, node, modules,
                                  quantized_graph,
                                  node_name_to_scope, is_input=True),
                    observer_dtype)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name] = (
                quantize_node(load_non_quantized(node.args[0]),
                              observer_module, node, modules,
                              quantized_graph,
                              node_name_to_scope, is_input=True),
                dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result is kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module' and
                is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since
                # _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # type: ignore[operator] # noqa: B950
                    assert len(out_quant_idxs) <= 1, \
                        "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                    node, qconfig, modules, quantized_graph,
                    node_name_to_scope, load_arg, is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig, modules)

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
            else:
                env[node.name] = result, torch.float
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
result = quantized_graph.node_copy(node, load_non_quantized) env[node.name] = result, torch.float continue # handle activation post process calls if node.op == 'call_module' and \ is_activation_post_process(modules[node.target]): insert_quantize_node(node, modules) elif node.op == 'placeholder': cur_placeholder_node_idx = placeholder_node_seen_cnt placeholder_node_seen_cnt += 1 if cur_placeholder_node_idx in input_quantized_idxs: env[node.name] = \ quantized_graph.node_copy( node, load_non_quantized), torch.quint8 else: env[node.name] = \ quantized_graph.node_copy(node, load_non_quantized), torch.float else: # copy quantized or non-quantized node # get_tensor_info_node like shape works for both # quantized and non-quantized input and output a non-Tensor # (we use None for dtype currently for non-Tensors) if is_get_tensor_info_node(node): env[node.name] = \ quantized_graph.node_copy(node, load_x), None else: env[node.name] = \ quantized_graph.node_copy(node, load_non_quantized), torch.float # remove activation post process act_post_process_removed_graph = Graph() remove_env: Dict[str, Node] = {} def load_arg_remove(a: Argument) -> Argument: return map_arg(a, lambda node: remove_env[node.name]) for node in quantized_graph.nodes: if node.op == 'output': act_post_process_removed_graph.output( map_arg(node.args[0], load_arg_remove)) continue if node.op == 'call_module' and \ is_activation_post_process(modules[node.target]): # remove activation post process node remove_env[node.name] = remove_env[node.args[0].name] else: remove_env[node.name] = act_post_process_removed_graph.node_copy( node, load_arg_remove) # removes qconfig and activation_post_process modules if _remove_qconfig_flag: _remove_qconfig(model) preserved_attributes = set( convert_custom_config_dict.get("preserved_attributes", [])) model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes) if not is_reference: model = fold_weight(model, node_name_to_scope) return model
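# load_non_quantized() above lazily appends a dequantize call through an FX
# Proxy. A hedged, self-contained illustration of that mechanism (the inserted
# node is left unused here, which is enough to show the graph surgery):
import torch
import torch.fx
from torch.fx import Proxy

def f(x):
    return torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)

gm = torch.fx.symbolic_trace(f)
quant_node = next(n for n in gm.graph.nodes if n.op == 'call_function')
with gm.graph.inserting_after(quant_node):
    deq_node = Proxy(quant_node).dequantize().node
print(gm.graph)  # now contains a call_method 'dequantize' on the quantize node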
def _convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False): assert not inplace, 'inplace convert is not supported yet' self.restore_state(observed) self.is_dynamic_quant = is_dynamic_quant # run weight observers before inserting quant dequant nodes # for dynamic quantization if self.is_dynamic_quant: self._run_weight_observers(observed) # move to cpu since we only have quantized cpu kernels observed.eval().cpu() observed_root = observed.root observed_graph = observed.graph if not inplace: observed_root = copy.deepcopy(observed_root) self.modules = dict(observed_root.named_modules()) matches = self._find_matches(observed.graph, self.modules, self.patterns) quants = self._find_quants(observed.graph, matches) self.quantized_graph = Graph() env = {} quant_env = {} def load_non_quantized(n): if n.name not in env: assert n.name in quant_env, \ 'trying to load float node but did not find node:' + n.name + \ ' in quantized environment:' + str(quant_env) env[n.name] = Proxy(quant_env[n.name]).dequantize().node return env[n.name] def load_quantized(n): if n.name not in quant_env: assert n.name in env, \ 'trying to load quantized node but did not find node:' + n.name + \ ' in float environment:' + str(env) assert n.name in quants, 'did not find quant object for node:' + n.name quant = quants[n.name][0] quant_env[n.name] = quant.convert(self, env[n.name]) return quant_env[n.name] def load_x(n): assert n.name in env or n.name in quant_env, \ 'node ' + n.name + ' does not exist in either of the environment' if n.name in quant_env: return quant_env[n.name] else: return env[n.name] def load_arg(quantized): """ if quantized is a list, then arg should be a list and the args with corresponding indexes will be quantized if quantized is a boolean, then all args will be quantized/not quantized if quantized is None, then we'll load the node as long as it exists """ assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized) def load_arg_impl(arg): if quantized is None: return map_arg(arg, load_x) if isinstance(quantized, bool): return map_arg(arg, load_quantized if quantized else load_non_quantized) elif isinstance(quantized, (tuple, list)): assert isinstance(arg, (tuple, list)), arg loaded_arg = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg): if i in quantized: loaded_arg.append(map_arg(a, load_quantized)) else: loaded_arg.append(map_arg(a, load_non_quantized)) return type(arg)(loaded_arg) return load_arg_impl def is_quantized(node): if isinstance(node, Node): assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment' # there might be nodes appearing in both environemnts, but quant_env will take # precedence if node.name in quant_env: return True elif node.name in env: return False elif isinstance(node, list): quantized = map(is_quantized, node) if all(quantized): return True elif not any(quantized): return False else: raise Exception("partially quantized inputs in list not handled yet") for node in observed_graph.nodes: root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None)) if root_node is node: result = obj.convert(self, node, load_arg) quantized = True # Need to get correct quantized/non-quantized state for the output of CopyNode if isinstance(obj, CopyNode): assert node.op in [ 'call_module', 'call_function', 'call_method'], \ 'CopyNode of type ' + node.op + ' is not handled' quantized = is_quantized(node.args[0]) # output of dynamic quantization is not 
quantized if self.is_dynamic_quant: quantized = False if quantized: quant_env[node.name] = result else: env[node.name] = result continue elif root_node is not None: continue # handle activation post process calls if node.op == 'call_module': if node.target.split('.')[-1].startswith('activation_post_process_'): observer_module = self.modules[node.target] prev_node = node.args[0] if prev_node.name in quant_env: # if previous node is already quantized, we'll just remove the activation_post_process quant_env[node.name] = quant_env[prev_node.name] continue # replace activation post process with quantization ops parent_name = '' scale, zero_point = observer_module.calculate_qparams() dtype = observer_module.dtype def is_per_channel(qscheme): return qscheme == torch.per_channel_affine or \ qscheme == torch.per_channel_symmetric if is_per_channel(observer_module.qscheme): ch_axis = int(observer_module.ch_axis) qparams = {'_scale_': scale, '_zero_point_': zero_point, '_axis': ch_axis, '_dtype_': dtype} quantize_op = torch.quantize_per_channel else: scale = float(scale) zero_point = int(zero_point) qparams = {'_scale_': scale, '_zero_point_': zero_point, '_dtype_': dtype} quantize_op = torch.quantize_per_tensor i = 0 def noattr(module, qparams, i): for name in qparams.keys(): if hasattr(module, name + str(i)): return False return True def get_next_i(module, qparams): i = 0 while not noattr(module, qparams, i): i += 1 return i parent_module = self.modules[parent_name] i = get_next_i(parent_module, qparams) inputs = [load_non_quantized(node.args[0])] for key, value in qparams.items(): setattr(parent_module, key + str(i), value) qparam_full_path = key + str(i) if parent_name: qparam_full_path = parent_name + '.' + qparam_full_path inputs.append(self.quantized_graph.create_node('get_param', qparam_full_path)) quant_env[node.name] = self.quantized_graph.create_node('call_function', quantize_op, tuple(inputs), {}) continue # dequantize inputs for the node that are not quantized env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) self.quantized_graph.output(load_non_quantized(observed_graph.result)) to_be_removed = [] for name, _ in observed_root.named_modules(): if name.split('.')[-1].startswith('activation_post_process_'): to_be_removed.append(name) for n in to_be_removed: delattr(observed_root, n) return GraphModule(observed_root, self.quantized_graph)
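# The activation_post_process replacement above registers scale/zero_point as
# attributes on the parent module and references them from the graph. A small
# sketch of the same idea on the current FX API (the code above predates it
# and uses a 'get_param' opcode; 'get_attr' is the modern equivalent):
import torch
from torch.fx import Graph, GraphModule

root = torch.nn.Module()
setattr(root, '_scale_0', 1.0)
setattr(root, '_zero_point_0', 0)

graph = Graph()
x = graph.placeholder('x')
scale = graph.get_attr('_scale_0')
zero_point = graph.get_attr('_zero_point_0')
q = graph.call_function(
    torch.quantize_per_tensor, (x, scale, zero_point, torch.quint8))
graph.output(q)

gm = GraphModule(root, graph)
print(gm(torch.randn(3)))  # a quantized tensor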
def add_loggers_to_model( gm: GraphModule, node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]], node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]], logger_cls: Callable, model_name: str, ) -> GraphModule: """ Takes the graph of gm, adds loggers to the output of each node in nodes_to_instrument. Returns a GraphModule with the new graph. """ new_graph = Graph() env: Dict[str, Any] = {} modules = dict(gm.named_modules()) def load_arg(a): return map_arg(a, lambda node: env[node.name]) for node in gm.graph.nodes: if node.op == 'output': new_graph.output( map_arg(_get_normalized_nth_input(node, gm, 0), load_arg)) continue if ((node in node_to_instrument_inputs_to_ref_node_name) or (node in node_to_instrument_outputs_to_ref_node_name)): fqn = _maybe_get_fqn(node, gm) if node in node_to_instrument_inputs_to_ref_node_name: ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[ node] # Ops such add and mul are special because either # one or two of the first two arguments can be tensors, # and if one argument is a tensor it can be first or # second (x + 1 versus 1 + x). arg_indices_to_log = get_arg_indices_of_inputs_to_log(node) for node_arg_idx in arg_indices_to_log: node_arg = _get_normalized_nth_input( node, gm, node_arg_idx) if type(node_arg) == Node: # create a single input logger prev_node = env[node_arg.name] env[node_arg.name] = _insert_logger_after_node( prev_node, gm, logger_cls, '_ns_logger_', node.name, model_name, ref_name, ref_node_type, NSSingleResultValuesType.NODE_INPUT.value, index_within_arg=0, index_of_arg=node_arg_idx, fqn=fqn) elif type( node_arg ) == torch.fx.immutable_collections.immutable_list: # create N input loggers, one for each node for arg_idx, arg in enumerate( node_arg ): # type: ignore[var-annotated, arg-type] prev_node = env[arg.name] env[prev_node.name] = _insert_logger_after_node( prev_node, gm, logger_cls, '_ns_logger_', node.name, model_name, ref_name, ref_node_type, NSSingleResultValuesType.NODE_INPUT.value, index_within_arg=arg_idx, index_of_arg=node_arg_idx, fqn=fqn) else: pass # ensure env is populated with base node # Note: runs for both inputs and outputs env[node.name] = new_graph.node_copy(node, load_arg) if node in node_to_instrument_outputs_to_ref_node_name: ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[ node] # add the logger after the base node env[node.name] = _insert_logger_after_node( env[node.name], gm, logger_cls, '_ns_logger_', node.name, model_name, ref_name, ref_node_type, NSSingleResultValuesType.NODE_OUTPUT.value, index_within_arg=0, index_of_arg=0, fqn=fqn) else: env[node.name] = new_graph.node_copy(node, load_arg) new_gm = GraphModule(gm, new_graph) return new_gm
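# add_loggers_to_model treats logger_cls as an observer-style module factory.
# A minimal stand-in (hedged sketch; the real Numeric Suite logger is
# constructed with and records the extra metadata passed above, e.g. ref_name,
# index_within_arg and fqn):
import torch

class MinimalLogger(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stats = []

    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        return x  # pass-through: inserting it does not change numerics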
def convert(model: GraphModule, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None, is_standalone_module: bool = False, _remove_qconfig_flag: bool = True, convert_qconfig_dict: Dict[str, Any] = None) -> torch.nn.Module: """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. Returns a quantized standalone module, whether input/output is quantized is specified by prepare_custom_config_dict, with input_quantized_idxs, output_quantized_idxs, please see docs for prepare_fx for details """ if convert_custom_config_dict is None: convert_custom_config_dict = {} patterns, node_name_to_scope, prepare_custom_config_dict, _ = restore_state( model) qconfig_map: Dict[ str, QConfigAny] = model._qconfig_map # type: ignore[assignment] # TODO this should be removed now that gpu support for quantization is being supported. # however in practice, as of 7/22/2021, certain functions that get called by convert expect # only cpu arguments. # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda', # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend. if not is_reference: model.cpu() # mapping from fully qualified module name to module instance # for example, # { # '': Model(...), # 'linear': Linear(...), # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), # } # We use remove_duplicate=False here because torch.cat uses # the same activation_post_process module instance but different names modules = dict(model.named_modules(remove_duplicate=False)) # TODO refactor this code once we update the prepare logic to have additional information on # which graph nodes have been observed and share that with convert to decide which observers to ignore. if convert_qconfig_dict: prepare_qconfig_dict: Dict[str, Dict[ Any, Any]] = model._qconfig_dict # type: ignore[assignment] modules_copy = copy.deepcopy(modules) convert_dict_to_ordered_dict(convert_qconfig_dict) if model._is_qat: additional_qat_module_mapping = prepare_custom_config_dict.get( "additional_qat_module_mapping", {}) convert_qconfig_dict = update_qconfig_for_qat( convert_qconfig_dict, additional_qat_module_mapping) convert_qconfig_dict = update_qconfig_for_fusion( model, convert_qconfig_dict) compare_prepare_convert_qconfig_dict( prepare_qconfig_dict, convert_qconfig_dict) # type: ignore[arg-type] convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope) # check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map # or are set to None in the convert_qconfig_map. 
for k, v in qconfig_map.items(): assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format( k) if convert_qconfig_map[k] is not None: assert qconfig_equals( v, convert_qconfig_map[k] ), 'Expected k {} to have the same value in prepare qconfig_dict \ and convert qconfig_dict, found {} updated to {}.'.format( k, v, convert_qconfig_map[k]) qconfig_map = convert_qconfig_map custom_module_classes = get_custom_module_class_keys( convert_custom_config_dict, "observed_to_quantized_custom_module_class") matches = find_matches(model.graph, modules, patterns, qconfig_map, custom_module_classes=custom_module_classes) if model._equalization_qconfig_map is not None: # If we want to do equalization then do the following: # Calculate the equalization scale, update the observers with the scaled # inputs, and scale the weight weight_eq_obs_dict = update_obs_for_equalization(model, modules) convert_eq_obs(model, modules, weight_eq_obs_dict) # always run weight observers in the top level forward method # for dynamic quant ops or weight only quant ops run_weight_observers(model) quantized_graph = Graph() env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict( lambda: defaultdict(Node)) # type: ignore[arg-type] graph_inputs: List[str] = [] for node in model.graph.nodes: if node.op == 'placeholder': graph_inputs.append(node.name) def load_non_quantized(n: Node) -> Node: assert n.name in env, \ 'trying to load float node but did not find ' + \ 'node:' + n.name + \ ' in env: ' + \ str(env) dtype_to_node = env[n.name] if torch.float in dtype_to_node: return dtype_to_node[torch.float] elif None in dtype_to_node: return dtype_to_node[None] else: quantized_node = None for dtype in [torch.quint8, torch.qint8, torch.float16]: if dtype in dtype_to_node: quantized_node = dtype_to_node[dtype] break assert quantized_node is not None, "Did not find a supported quantized dtype:{}".format( dtype_to_node) env[n.name][torch.float] = Proxy(quantized_node).dequantize().node return env[n.name][torch.float] def load_quantized(dtype: torch.dtype): def load_quantized_impl(n: Node): assert n.name in env, \ 'trying to load quantized node but did not find node:' + \ n.name + ' in environment:' + str(env) dtype_to_node = env[n.name] local_dtype: Optional[torch.dtype] = dtype if local_dtype == torch.float and local_dtype not in dtype_to_node: local_dtype = None if local_dtype in [torch.float, None]: return load_non_quantized(n) assert local_dtype in dtype_to_node, f'Expecting {dtype} in {dtype_to_node}' return dtype_to_node[local_dtype] return load_quantized_impl def load_x(n: Node) -> Node: assert n.name in env, \ 'node ' + n.name + ' does not exist in environment' dtype_to_node = env[n.name] dtypes = [ torch.quint8, torch.qint8, torch.float16, torch.float32, None ] for dtype in dtypes: if dtype in dtype_to_node: return dtype_to_node[dtype] raise Exception( f'dtype {dtype} not found in environment: {dtype_to_node} for node {n.name}' ) def load_arg( quantized: Optional[Union[List[int], Dict[int, torch.dtype], torch.dtype, Tuple[int, ...]]] ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, torch.dtype, list or tuple - if quantized is None, then we'll load the node as long as it exists - if quantized is a dtype, then all args will be quantized to the specific dtype - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=torch.float) - if quantized is a list or tuple, then arg should be a list and the args with corresponding indexes will be quantized to 
torch.quint8 Output: fn which takes arg_or_args, and loads them from the corresponding environment depending on the value of quantized. """ assert quantized is None or \ isinstance(quantized, (tuple, list, dict, torch.dtype)), type(quantized) if isinstance(quantized, (tuple, list, dict)) and len(quantized) == 0: # empty tuple or list means nothing is quantized quantized = torch.float def load_arg_impl(arg_or_args): # we'll update the format of `quantized` # to better match arg_or_args updated_quantized: Optional[Union[List[int], torch.dtype, Dict[int, torch.dtype], Tuple[int, ...]]] = quantized if isinstance(quantized, (tuple, list)) and \ len(quantized) == 1 and isinstance(arg_or_args, Node): # when argument is one Node instead of tuple, we just need to check # 0 is in the quantized list if 0 in quantized: updated_quantized = torch.quint8 if updated_quantized is None: return map_arg(arg_or_args, load_x) if isinstance(updated_quantized, torch.dtype): return map_arg(arg_or_args, load_quantized(updated_quantized)) elif isinstance(updated_quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): if i in updated_quantized: # Currently it's hardcoded to torch.quint8, we can extend this # in the future to support all quantized # dtypes loaded_args.append( map_arg(a, load_quantized(torch.quint8))) else: loaded_args.append(map_arg(a, load_non_quantized)) return type(arg_or_args)(loaded_args) elif isinstance(updated_quantized, dict): loaded_args = [] for i, a in enumerate(arg_or_args): if i in updated_quantized: loaded_args.append( map_arg(a, load_quantized(updated_quantized[i]))) else: loaded_args.append(map_arg(a, load_non_quantized)) return type(arg_or_args)(loaded_args) return load_arg_impl def node_arg_is_quantized(node_arg: Any) -> bool: if isinstance(node_arg, Node): assert node_arg.name in env, \ 'Expecting node_arg to be in the environment' if node_arg.name in env: dtype_to_node = env[node_arg.name] return any([ x in dtype_to_node for x in [torch.quint8, torch.qint8, torch.float16] ]) else: return False elif isinstance(node_arg, list): quantized = map(node_arg_is_quantized, node_arg) if all(quantized): return True elif not any(quantized): return False else: raise Exception( "partially quantized inputs in list not handled yet") else: return False def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool: """ Check if output node is quantized or not """ assert modules is not None # for some ops the output is quantized only when `is_reference` is True # and when `is_reference` is False, it has limited qconfig # support, for example `add` # ideally this check should not happen here, it should happen either in # prepare or during lowering, we don't need this check # after the default path is changed to produce reference patterns quantized = obj.is_output_quantized(qconfig) # Need to get correct quantized/non-quantized state forn the output # of FixedQParamsQuantizeHandler # TODO: we may want to try to remove the special case here # as well if obj.should_mark_output_quantized_from_input_quantized_status( qconfig): assert node.op in [ 'call_module', 'call_function', 'call_method'], \ 'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled' # TODO: need to extend this to consider all relevant args instead of just arg[0] quantized = node_arg_is_quantized(node.args[0]) # the output is 
unquantized if the node is not a CopyNode # or the activation is not statically quantized if not activation_is_statically_quantized(qconfig) or \ not obj.input_output_observed(): quantized = False if node_return_type_is_int(node): quantized = False return quantized def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> None: """ Given an activation_post_process module call node, insert a quantize node""" assert modules is not None assert isinstance(node.target, str) observer_module = modules[node.target] prev_node = node.args[0] if observer_module.dtype == torch.float32: # copy the observer for fp32 dtype env[node.name][torch.float] = quantized_graph.node_copy( node, load_non_quantized) elif isinstance(prev_node, Node) and prev_node.name in env: # if previous node is already quantized, we'll just remove the # activation_post_process prev_dtype_to_node: Dict[Optional[torch.dtype], Node] = env[prev_node.name] current_dtype: Optional[torch.dtype] = observer_module.dtype # type: ignore[assignment] if current_dtype in prev_dtype_to_node: env[node.name][current_dtype] = prev_dtype_to_node[current_dtype] else: root_module = modules[""] assert isinstance(prev_node, Node) observer_dtype: torch.dtype = observer_module.dtype # type: ignore[assignment] env[node.name][observer_dtype] = \ quantize_node( load_non_quantized(prev_node), observer_module, node, modules, quantized_graph, node_name_to_scope, is_input=True) else: # replace activation post process with quantization ops root_module = modules[""] assert isinstance(node.args[0], Node) dtype: torch.dtype = observer_module.dtype # type: ignore[assignment] env[node.name][dtype] = \ quantize_node( load_non_quantized(node.args[0]), observer_module, node, modules, quantized_graph, node_name_to_scope, is_input=True) # additional state to override inputs to be quantized, if specified # by the user placeholder_node_seen_cnt = 0 output_node_seen_cnt = 0 input_quantized_idxs: List[int] = prepare_custom_config_dict.get( "input_quantized_idxs", []) output_quantized_idxs: List[int] = prepare_custom_config_dict.get( "output_quantized_idxs", []) for node in model.graph.nodes: if node.op == "output": cur_output_node_idx = output_node_seen_cnt output_node_seen_cnt += 1 if cur_output_node_idx in output_quantized_idxs: # Results are kept quantized if the user specified the # output_quantized_idxs override. graph_output = map_arg(node.args[0], load_x) else: graph_output = map_arg(node.args[0], load_non_quantized) quantized_graph.output(graph_output) continue root_node, matched, matched_pattern, obj, qconfig = \ matches.get(node.name, (None, None, None, None, None)) if root_node is node: is_observed_standalone_module_node = ( node.op == 'call_module' and is_observed_standalone_module(modules[node.target])) if qconfig is None and not is_observed_standalone_module_node: result = quantized_graph.node_copy(node, load_non_quantized) quantized = False # If there are QAT swapped modules in the graph that we don't want to quantize, revert them back to FP32 ones.
if node.op == 'call_module' and type(modules[ node.target]) in DEFAULT_QAT_MODULE_MAPPINGS.values(): float_mod = modules[node.target].to_float() setattr(model, node.name, float_mod) with model.graph.inserting_before(node): new_float_node = model.graph.create_node( 'call_module', node.name, node.args, node.kwargs) else: assert obj is not None # We will get whether the output is quantized or not before # convert for standalone module and after convert # for non-standalone module, since _standalone_module_output_quantized_idxs # is only available in observed standalone module if is_observed_standalone_module_node: out_quant_idxs = modules[ node. target]._standalone_module_output_quantized_idxs.tolist( ) # noqa: B950 assert len( out_quant_idxs ) <= 1, "Currently standalone only support one output" quantized = 0 in out_quant_idxs qconfig = qconfig_map[node.name] # Note: load_arg can be overwritten in the convert method when used to # create Node in graph result = obj.convert( node, qconfig, modules, quantized_graph, node_name_to_scope, load_arg, is_reference=is_reference, convert_custom_config_dict=convert_custom_config_dict) if not is_observed_standalone_module_node: quantized = is_output_quantized(node, obj, qconfig, modules) if quantized: env[node.name][activation_dtype(qconfig)] = result else: env[node.name][torch.float] = result continue elif root_node is not None: if qconfig is None: # This branch is hit if all of these conditions are met: # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu) # 2. the current node is not the "root_node" of the pattern # 3. quantization for this pattern is disabled # # In this case, we need to make sure to populate the env with # intermediate nodes manually, because the QuantizeHandler.convert # function will not be called. 
result = quantized_graph.node_copy(node, load_non_quantized) env[node.name][torch.float] = result continue # handle activation post process calls if node.op == 'call_module' and \ is_activation_post_process(modules[node.target]): insert_quantize_node(node, modules) elif node.op == 'placeholder': cur_placeholder_node_idx = placeholder_node_seen_cnt placeholder_node_seen_cnt += 1 if cur_placeholder_node_idx in input_quantized_idxs: env[node.name][torch.quint8] = quantized_graph.node_copy( node, load_non_quantized) else: env[node.name][torch.float] = \ quantized_graph.node_copy(node, load_non_quantized) else: # copy quantized or non-quantized node # get_tensor_info_node like shape works for both # quantized and non-quantized input and output a non-Tensor # (we use None for dtype currently for non-Tensors) if is_get_tensor_info_node(node): env[node.name][None] = \ quantized_graph.node_copy(node, load_x) else: env[node.name][torch.float] = \ quantized_graph.node_copy(node, load_non_quantized) # remove activation post process act_post_process_removed_graph = Graph() remove_env: Dict[str, Node] = {} def load_arg_remove(a: Argument) -> Argument: return map_arg(a, lambda node: remove_env[node.name]) for node in quantized_graph.nodes: if node.op == 'output': act_post_process_removed_graph.output( map_arg(node.args[0], load_arg_remove)) continue if node.op == 'call_module' and \ is_activation_post_process(modules[node.target]): # remove activation post process node remove_env[node.name] = remove_env[node.args[0].name] else: remove_env[node.name] = act_post_process_removed_graph.node_copy( node, load_arg_remove) # removes qconfig and activation_post_process modules if _remove_qconfig_flag: _remove_qconfig(model) preserved_attributes = set( convert_custom_config_dict.get("preserved_attributes", [])) model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes) if not is_reference: model = duplicate_dequantize_node(model) model = fold_weight(model, node_name_to_scope) model = lower_to_fbgemm(model) model = remove_quant_dequant_pairs(model) model = remove_extra_dequantize(model) return model
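# For context, a hedged end-to-end usage sketch of the flow this convert()
# participates in (API names as of the torch.quantization FX era shown here;
# newer releases moved these under torch.ao.quantization and prepare_fx also
# requires example_inputs):
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(float_model, qconfig_dict)
prepared(torch.randn(2, 4))        # calibration fills the observers
quantized = convert_fx(prepared)   # observers become quantize/dequantize ops
print(quantized.graph)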
def test_type_check_conv2D_2_fully_static(self): annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)] input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (1, 2, 2, 3)] intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, Dyn, Dyn, Dyn)] in_planes_list = [2, 5, 15, 15, 2] stride_list = [1, 2, 3, 2, 2] out_planes_list = [2, 5, 15, 15, 2] groups_list = [1, 5, 5, 5, 2] dilation_list = [1, 2, 3, 3, 3] padding_list = [1, 2, 3, 3, 3] kernel_size_list = [1, 2, 3, 3, 3] output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)] for i in range(5): annotation = annotation_list[i] input = input_list[i] in_planes = in_planes_list[i] stride = stride_list[i] out_planes = out_planes_list[i] groups = groups_list[i] dilation = dilation_list[i] padding = padding_list[i] kernel_size = kernel_size_list[i] intermediate_type = intermediate_types[i] class BasicBlock(torch.nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): super(BasicBlock, self).__init__() self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation) def forward(self, x): out = self.conv1(x) return out B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = TensorType(annotation) b = B.forward(torch.rand(input)) tc = GraphTypeChecker({}, traced) tc.type_check() for n in graph.nodes: if n.op == 'output': assert is_consistent(n.type, TensorType(b.size())) # test with intermediate annotations class BasicBlock(torch.nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): super(BasicBlock, self).__init__() self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation) def forward(self, x): out = self.conv1(x) return out B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # populate our intermediate nodes for n in traced.graph.nodes: if n.op == 'call_module': n.type = TensorType(intermediate_type) tc = GraphTypeChecker({}, traced) tc.type_check() for n in traced.graph.nodes: if n.op == 'output': assert n.type == TensorType(output_types[i]) assert is_consistent(n.type, TensorType(b.size()))
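# The expected shapes in the test above follow the standard Conv2d size
# arithmetic; a quick self-contained check of the closed-form formula against
# an actual convolution (first test case: height 3 -> 5, width 5 -> 7):
import math
import torch

def conv2d_out_dim(size, kernel_size, stride, padding, dilation):
    return math.floor(
        (size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

conv = torch.nn.Conv2d(2, 2, kernel_size=1, stride=1,
                       padding=1, dilation=1, bias=False)
out = conv(torch.rand(1, 2, 3, 5))
assert out.shape[2] == conv2d_out_dim(3, 1, 1, 1, 1)  # height: 5
assert out.shape[3] == conv2d_out_dim(5, 1, 1, 1, 1)  # width: 7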
def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser") -> Tuple[fx.GraphModule, fx.GraphModule]: """ Partitions the joint graph such that the backward recomputes the forward. Recomputing helps in trading off memory bandwidth with computation. To create the fwd and bwd graphs, we copy the joint graph, manually set the outputs to just the original forward or backward outputs, and then run the resulting graphs through dead code elimination. .. warning:: This API is experimental and likely to change. Args: joint_module(fx.GraphModule): The joint forward and backward graph. This is the result of AOT Autograd tracing. Returns: Returns the generated forward and backward Fx graph modules. """ try: import networkx as nx except ImportError: raise RuntimeError( "Need networkx installed to perform smart recomputation heuristics" ) joint_module.graph.eliminate_dead_code() joint_module.recompile() fx_g = joint_module.graph # add the CSE pass cse_graph = fx_graph_cse(fx_g) joint_module.graph = cse_graph full_bw_graph = joint_module.graph name_to_node = {} for node in joint_module.graph.nodes: name_to_node[node.name] = node def classify_nodes(joint_module): required_bw_nodes = set() for node in joint_module.graph.nodes: if node.op == 'placeholder' and "tangents" in node.target: required_bw_nodes.add(node) if node in required_bw_nodes: for user in node.users: required_bw_nodes.add(user) primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module) forward_only_graph = _extract_graph_with_inputs_outputs( joint_module.graph, primal_inputs, fwd_outputs) required_fw_nodes = { name_to_node[node.name] for node in forward_only_graph.nodes if node.op != 'output' } unclaimed_nodes = { node for node in joint_module.graph.nodes if node not in required_fw_nodes and node not in required_bw_nodes } return required_fw_nodes, required_bw_nodes, unclaimed_nodes required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes( joint_module) for node in reversed(joint_module.graph.nodes): if node not in required_fw_nodes: node.dist_from_bw = 0 else: node.dist_from_bw = int(1e9) for user in node.users: node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) aten = torch.ops.aten prims = torch.ops.prims pointwise_ops = [ aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward ] # noqa: E501 if compiler == "inductor": pointwise_ops += [ prims.div, prims.convert_element_type, aten.sign, aten.clone ] # noqa: E501 misc_ops = [aten.to, aten.type_as, operator.getitem] reduction_ops = [ aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax ] # noqa: E501 if compiler == "inductor": reduction_ops += [prims.var, prims.sum, aten.var] # not recomputed by
default since these are kinda expensive/hard to fuse into # norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward] # noqa: E501 # Not used by default since NVFuser can't fuse view ops # view_ops = [aten.expand, aten.clone, aten.transpose, aten.t, aten.view, aten._unsafe_view, aten.permute, aten.transpose, aten.t, aten._reshape_alias, aten.squeeze, aten.unsqueeze, aten.reshape, aten.cat, aten.slice, aten.split, aten.select, aten.repeat] # noqa: E501 # These are the view ops that NVFuser can fuse view_ops = [aten.squeeze, aten.unsqueeze] if compiler == "inductor": view_ops += [ prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors ] # noqa: E501 random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] compute_intensive_ops = [ aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d ] # noqa: E501 unrecomputable_ops = random_ops + compute_intensive_ops recomputable_ops = set(pointwise_ops + misc_ops + reduction_ops + view_ops) fusible_ops = recomputable_ops | set(random_ops) if AOT_PARTITIONER_DEBUG: joint_module_ops = set( str(node.target._overloadpacket) for node in joint_module.graph.nodes if node.op == "call_function" and hasattr(node.target, "_overloadpacket")) ops_ignored = joint_module_ops - set( [str(i) for i in recomputable_ops]) print("Ops banned from rematerialization: ", ops_ignored) print() AGGRESSIVE_RECOMPUTATION = False def _maybe_size_of(node): if 'tensor_meta' in node.meta: return _size_of(node.meta['tensor_meta']) return 0 def ban_recomputation(node): if AGGRESSIVE_RECOMPUTATION: return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops) else: if node.op != 'call_function': return False if get_aten_target(node) not in recomputable_ops: return True if node.target == operator.getitem: return False if compiler == "inductor" and node.dist_from_bw > 4: return True # If the output of an op is 4x smaller (arbitrary choice), # then we don't allow recomputation. 
if 'tensor_meta' not in node.meta: return False input_tensors_size = sum( _maybe_size_of(i) for i in node.args if isinstance(i, fx.Node)) output_size = _size_of(node.meta['tensor_meta']) return (output_size * 4 < input_tensors_size) def is_fusible(a, b): return get_aten_target(a) in fusible_ops and get_aten_target( b) in fusible_ops def is_materialized(node): if node.op == 'placeholder': return True return not all(is_fusible(node, user) for user in node.users) def get_node_weight(node): mem_sz = _size_of(node.meta['tensor_meta']) # Heuristic to bias towards nodes closer to the backwards pass # Complete guess about current value mem_sz = int(mem_sz * (1.1**max(min(node.dist_from_bw, 100), 1))) # mem_sz = int(mem_sz + node.dist_from_bw) if is_materialized(node): return mem_sz else: return mem_sz * 2 nx_graph = nx.DiGraph() for node in full_bw_graph.nodes: if node.op == 'output': continue if node in required_bw_nodes: nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) continue if node.op == 'placeholder' and "primals" in node.target: nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) # If a node can't be recomputed (too expensive or involves randomness), # we prevent it from being recomputed by adding an inf edge to the source # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. if ban_recomputation(node) and node in required_fw_nodes: nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) if 'tensor_meta' not in node.meta: weight = math.inf else: weight = get_node_weight(node) # Creates the weights on the "node" edge nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) for user in node.users: nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") reachable, non_reachable = partition cutset = set() for u, nbrs in ((n, nx_graph[n]) for n in reachable): cutset.update((u, v) for v in nbrs if v in non_reachable) cut_nodes = set() for node_in, node_out in cutset: assert node_in[:-3] == node_out[:-4] node_name = node_in[:-3] cut_nodes.add(node_name) # To make this stuff deterministic node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]) fw_module, bw_module = _extract_fwd_bwd_modules(joint_module, saved_values) if AOT_PARTITIONER_DEBUG: print( "Theoretical Activations Stored: ", sum([_size_of(i.meta['tensor_meta']) for i in saved_values]) / 1e9) fw_module_nodes = set([ node.name for node in fw_module.graph.nodes if node.op == 'call_function' ]) bw_module_nodes = set([ node.name for node in bw_module.graph.nodes if node.op == 'call_function' ]) remat_nodes = fw_module_nodes & bw_module_nodes counts = defaultdict(int) for node in fw_module.graph.nodes: if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'): counts[str(node.target._overloadpacket)] += 1 print("# nodes rematerialized: ", len(remat_nodes)) print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True)) return fw_module, bw_module
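# A toy version of the source/sink flow network built above, assuming networkx
# is installed. Each tensor becomes an _in -> _out edge whose capacity is the
# cost of saving it; the minimum cut then selects the cheapest set of
# activations to materialize for the backward pass.
import math
import networkx as nx

g = nx.DiGraph()
g.add_edge("source", "a_in", capacity=math.inf)
g.add_edge("a_in", "a_out", capacity=3)    # small activation
g.add_edge("a_out", "b_in", capacity=math.inf)
g.add_edge("b_in", "b_out", capacity=10)   # large activation
g.add_edge("b_out", "sink", capacity=math.inf)

cut_value, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink")
print(cut_value)  # 3: saving "a" and recomputing "b" from it is cheapest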
def test_typecheck_basicblock(self): class BasicBlock(torch.nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1): super(BasicBlock, self).__init__() norm_layer = torch.nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError( 'BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError( "Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = torch.nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: TensorType((2, 2, 4, 5))): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out B = BasicBlock(2, 2) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") tc = GraphTypeChecker({}, traced) tc.type_check() for n in traced.graph.nodes: if n.target == 'output': assert isinstance(n.type, TensorType) assert torch.Size(n.type.__args__) == B.forward( torch.rand(2, 2, 4, 5)).size()
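# conv3x3 is used above without a definition; for reference, the torchvision
# ResNet helper it presumably refers to looks like this (assumption based on
# torchvision.models.resnet.conv3x3):
import torch

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                           padding=dilation, groups=groups,
                           dilation=dilation, bias=False)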
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant): if not inplace: model = copy.deepcopy(model) self.is_dynamic_quant = is_dynamic_quant if self.is_dynamic_quant: self.patterns = get_dynamic_quant_patterns() else: self.patterns = get_quant_patterns() propagate_qconfig_(model, qconfig_dict) if model.training: self._qat_swap_modules(model) self.modules = dict(model.named_modules()) # map from node name to qconfig, used in _find_matches self._generate_qconfig_map(model, model.graph) # match the patterns that will get quantized matches = self._find_matches(model.graph, self.modules, self.patterns) # find _inputs_ to matched nodes that are not quantized, these # have to be quantized, which requires measuring stats, # initialize an DefaultQuant object for each quants = self._find_quants(model.graph, matches) self.activation_post_process_map = dict() env = {} observed_graph = Graph() observed_node_names_set = set() def load_arg(a): return map_arg(a, lambda node: env[node.name]) for node in model.graph.nodes: if node.name in observed_node_names_set: continue prefix = node.name + '_activation_post_process_' root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None)) if root_node is None: env[node.name] = observed_graph.node_copy(node, load_arg) elif root_node is node: env[node.name] = observed_graph.node_copy(node, load_arg) if qconfig is None: continue def insert_observer(node, observer, device): get_new_observer_name = get_new_attr_name_with_prefix(prefix) observer_name = get_new_observer_name(model) setattr(model, observer_name, observer) self.activation_post_process_map[node.name] = observer env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) observed_node_names_set.add(node.name) if device: getattr(model, observer_name).to(device) if isinstance(obj, CustomModuleQuantizeHandler): custom_module = self.modules[node.target] observed_custom_module_class = \ get_observed_custom_module_class(type(custom_module)) observed_custom_module = \ observed_custom_module_class.from_float(custom_module) mark_observed_custom_module(observed_custom_module, type(custom_module)) parent_name, name = _parent_name(node.target) setattr(self.modules[parent_name], name, observed_custom_module) # don't need to insert observer for output in dynamic quantization if self.is_dynamic_quant: continue # inserting observers for output of observed module, or mark the output # as observed if isinstance(obj, CopyNode): assert node.op in [ 'call_module', 'call_function', 'call_method'], \ 'CopyNode of type ' + node.op + ' is not handled' def is_observed(input_arg): if isinstance(input_arg, Node): return input_arg.name in observed_node_names_set elif isinstance(input_arg, list): return all(map(is_observed, input_arg)) # propagate observed property from input if is_observed(node.args[0]): observed_node_names_set.add(node.name) elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes: if node.args[0].name in observed_node_names_set: observed_node_names_set.add(node.name) elif qconfig is not None and obj.all_nodes: # observer for outputs new_observer = qconfig.activation() # respect device affinity when adding observers device = assert_and_get_unique_device(model) insert_observer(node, new_observer, device) else: env[node.name] = observed_graph.node_copy(node, load_arg) if node.name not in observed_node_names_set and node.name in quants: get_new_observer_name = get_new_attr_name_with_prefix(prefix) observer_name = get_new_observer_name(model) 
_, qconfig, is_weight = quants[node.name] if qconfig is not None: new_observer = \ qconfig.weight() if is_weight else qconfig.activation() # respect device affinity when adding observers device = assert_and_get_unique_device(model) if device: new_observer.to(device) self.activation_post_process_map[node.name] = new_observer setattr(model, observer_name, self.activation_post_process_map[node.name]) env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) observed_node_names_set.add(node.name) observed_graph.output(load_arg(model.graph.result)) model = GraphModule(model, observed_graph) self.save_state(model) return model
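# get_new_attr_name_with_prefix above produces collision-free module attribute
# names for the inserted observers; a minimal equivalent (hedged stand-in for
# the torch.quantization FX utility, which additionally sanitizes the prefix):
def new_attr_name_with_prefix(prefix):
    def get_new_attr_name(module):
        i = 0
        while hasattr(module, f"{prefix}{i}"):
            i += 1
        return f"{prefix}{i}"
    return get_new_attr_name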
def test_type_maxpool2d_fully_static(self): annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)] input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (2, 2, 10, 10)] intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2), (10, 15, 2, 3), (2, Dyn, Dyn, Dyn)] stride_list = [1, 2, 3, 2, 1] dilation_list = [1, 2, 3, 3, 2] padding_list = [1, 2, 3, 3, 1] kernel_size_list = [2, 4, 6, 6, 3] output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2), (10, 15, 2, 3), (2, Dyn, Dyn, 8)] for i in range(5): annotation = annotation_list[i] input = input_list[i] stride = stride_list[i] dilation = dilation_list[i] padding = padding_list[i] kernel_size = kernel_size_list[i] intermediate_type = intermediate_types[i] class BasicBlock(torch.nn.Module): def __init__(self, kernel_size, stride, padding, dilation): super(BasicBlock, self).__init__() self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=False, ceil_mode=False) def forward(self, x): out = self.pool(x) return out B = BasicBlock(kernel_size, stride, padding, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = TensorType(annotation) b = B.forward(torch.rand(input)) tc = GraphTypeChecker({}, traced) tc.type_check() for n in graph.nodes: if n.op == 'output': assert is_consistent(n.type, TensorType(b.size())) # test with intermediate annotations class BasicBlock(torch.nn.Module): def __init__(self, kernel_size, stride, padding, dilation): super(BasicBlock, self).__init__() self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=False, ceil_mode=False) def forward(self, x): out = self.pool(x) return out B = BasicBlock(kernel_size, stride, padding, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = TensorType(annotation) # populate our intermediate nodes for n in traced.graph.nodes: if n.op == 'call_module': n.type = TensorType(intermediate_type) tc = GraphTypeChecker({}, traced) tc.type_check() for n in traced.graph.nodes: if n.op == 'output': assert n.type == TensorType(output_types[i]) assert is_consistent(n.type, TensorType(b.size()))
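# is_consistent above implements the "consistency" relation of gradual typing:
# Dyn is consistent with anything, while concrete dimensions must match. A
# hedged restatement of the TensorType case, using a string sentinel for Dyn:
def tensor_types_consistent(dims1, dims2, Dyn='Dyn'):
    if len(dims1) != len(dims2):
        return False
    return all(d1 == Dyn or d2 == Dyn or d1 == d2
               for d1, d2 in zip(dims1, dims2))

assert tensor_types_consistent((10, 'Dyn', 13, 14), (10, 15, 13, 14))
assert not tensor_types_consistent((2, 5, 6, 9), (2, 5, 6, 8))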
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False,
             is_standalone_module=False):
    """ standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    For a standalone module: the inputs are quantized by the parent module.
    This checks `_standalone_module_observed_input_idxs` of the input observed
    model and treats those inputs as quantized. It also does not dequantize
    the final output.
    Returns a quantized standalone module which accepts quantized input
    (if needed) and produces quantized output (if needed).
    """
    self.restore_state(model)
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    # run weight observers before inserting quant dequant nodes
    # for dynamic quantization
    if self.is_dynamic_quant:
        self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    matches = self._find_matches(model.graph, self.modules, self.patterns)
    quants = self._find_quants(model.graph, matches)
    self.quantized_graph = Graph()
    env = {}
    quant_env = {}

    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n):
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + str(env) + \
                ' quant_env:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + n.name + \
                ' in float environment:' + str(env)
            assert n.name in quants, 'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """
        Input: quantized, which can be None, a boolean, a list or a tuple
        - if quantized is a list or tuple, then arg should be a list and
          the args with corresponding indexes will be quantized
        - if quantized is a boolean, then all args will be
          quantized/not quantized
        - if quantized is None, then we'll load the node as long as it
          exists

        Output: fn which takes arg_or_args, and loads them from the
        corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def is_quantized(node):
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, \
                'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but
            # quant_env will take precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            # materialize the results; all() would otherwise exhaust the
            # iterator before any() runs
            quantized = list(map(is_quantized, node))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")

    for node in model.graph.nodes:
        root_node, matched, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is node:
            if qconfig is None:
                result = self.quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                result = obj.convert(self, node, load_arg)
                if node.op == 'call_module' and is_observed_standalone_module(
                        self.modules[node.target]):
                    quantized = self.modules[node.target]._output_is_observed
                else:
                    quantized = True

                # Need to get correct quantized/non-quantized state for the
                # output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            continue

        # handle activation post process calls
        if node.op == 'call_module':
            if is_activation_post_process(self.modules[node.target]):
                observer_module = self.modules[node.target]
                prev_node = node.args[0]
                if observer_module.dtype == torch.float16:
                    # activations are not quantized for
                    # fp16 dynamic quantization
                    # copy the activation_post_process node here
                    # since we may need it when we insert prepack
                    # op for weight of linear, this will be removed
                    # later in a separate pass
                    env[node.name] = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    continue
                if prev_node.name in quant_env:
                    # if the previous node is already quantized, we'll just
                    # remove the activation_post_process
                    quant_env[node.name] = quant_env[prev_node.name]
                    continue
                # replace activation post process with quantization ops
                root_module = self.modules['']
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)
                continue

        if is_standalone_module and node.op == 'placeholder' and \
                graph_inputs.index(node.name) in \
                model._standalone_module_observed_input_idxs:
            # the node is quantized in the parent module
            quant_env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)
        else:
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)

    if is_standalone_module:
        # results are kept quantized in the quantized standalone module
        graph_output = map_arg(model.graph.result, load_x)
    else:
        graph_output = map_arg(model.graph.result, load_non_quantized)
    self.quantized_graph.output(graph_output)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'call_module' and \
                is_activation_post_process(self.modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]
        else:
            env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg)
    act_post_process_removed_graph.output(
        map_arg(self.quantized_graph.result, load_arg))

    module_dict = dict(model.named_modules())
    to_be_removed = []
    for name, module in model.named_modules():
        if is_activation_post_process(module) and \
                not is_submodule_of_fake_quant(name, module, module_dict):
            to_be_removed.append(name)
    for n in to_be_removed:
        delattr(model, n)
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model
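# --- Example (not part of the library code above) ---
# A minimal, self-contained sketch of the env-based graph rewriting pattern
# that _convert relies on: each node of the source graph is copied into a new
# graph through an `env` dict keyed by node name, which is what lets a pass
# drop, replace, or wrap nodes as it walks the graph. Assumes a torch.fx
# version where the graph output is itself a node.
import torch
import torch.fx
from torch.fx import Graph, GraphModule
from torch.fx.node import map_arg

def copy_graph_with_env(gm: GraphModule) -> GraphModule:
    new_graph = Graph()
    env = {}

    def load_arg(a):
        # resolve each arg to its already-copied counterpart
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue
        # a real pass (like _convert) would branch here to rewrite nodes
        env[node.name] = new_graph.node_copy(node, load_arg)
    return GraphModule(gm, new_graph)

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

gm = torch.fx.symbolic_trace(M())
copied = copy_graph_with_env(gm)
x = torch.randn(2)
assert torch.allclose(gm(x), copied(x))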
def transform(traced):
    # Append a call to Tensor.neg() after the traced graph's current result,
    # and make the negated value the new output.
    new_graph = copy.deepcopy(traced.graph)
    neg_out = new_graph.create_node(
        op='call_method', target='neg', args=(new_graph.result,), kwargs={})
    new_graph.output(neg_out)
    return GraphModule(traced, new_graph)
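# --- Example (illustrative) ---
# Hedged usage sketch for transform(): it assumes the same early torch.fx API
# as the function itself (where Graph exposes a `.result` attribute). `ReluMod`
# is a hypothetical module used only for illustration.
import torch
import torch.fx

class ReluMod(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

traced = torch.fx.symbolic_trace(ReluMod())
negated = transform(traced)
x = torch.randn(3)
# the transformed module computes relu(x).neg()
assert torch.allclose(negated(x), torch.relu(x).neg())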
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """
    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and
    # inserting the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_observer = \
            node_b.op == 'call_module' and \
            is_activation_post_process(modules[node_b.target])
        node_b_is_start_node = \
            node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = \
            node_b in end_node_b_to_matched_subgraph_a_and_name

        if node_b_is_observer:
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]

        elif (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first
                        # logger is added
                        prev_node_c_list = [env_c[arg.name]
                                            for arg in node_b.args[0]]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls,
                                '_ns_logger_b_inp_', node_b.name, name_b,
                                ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0)
                    else:
                        # logging of inputs which are neither Nodes nor lists
                        # is not supported yet
                        raise AssertionError(
                            f"type {type(node_b.args[0])} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to
            # clarify code intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:

                    # cast dtype from the dtype of node_c's input to the dtype
                    # of node_a's input (dequant, etc)
                    prev_node_c = node_c.args[0]
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast
                        if isinstance(prev_node_c, Node):
                            prev_node_c = prev_node_c.args[0]
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [arg.args[0] for arg in prev_node_c]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b,
                        graph_c, node_b.name + '_dtype_cast_', logger_cls,
                        node_type_to_io_type_map)
                    # note: not inserting to env_c because all nodes which use
                    # the dtype casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #           (dtype_cast_node)+
                    #          /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # TODO: explain this
                        ref_node_name = ''
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node, gm_b, logger_cls,
                                '_ns_logger_a_inp_', ref_node_name, name_a,
                                ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0, index_of_arg=0)
                            input_logger: Union[Node, List[Node]] = \
                                dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in \
                                    enumerate(dtype_cast_node):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner, gm_b, logger_cls,
                                    '_ns_logger_a_inp_', ref_node_name, name_a,
                                    ref_name,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0)
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                        # subgraph so far:
                        #
                        #       (dtype_cast_node)+ -> (logger_a_input)?
                        #      /
                        # prev_node_c -> (logger_c_input)? -> node_start_c

                    # hook up the new mod_a copy to be in the graph, receiving
                    # the same inputs as mod_b does, with dtype cast to match a
                    # Some ops, such as LSTMs, have two non-param inputs. If we
                    # have such an op, pass the second param as well. Note:
                    # dtype casting for the second param is not implemented
                    # yet, it can be added later if there is a use case.
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = \
                        get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                    if num_non_param_args_node_a == 2:
                        node_c_second_non_param_arg = node_c.args[1]
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node, node_c_second_non_param_arg,
                        subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    if should_log_inputs:
                        # When we created the input logger, we left the
                        # ref_node_name as an empty string, because the
                        # subgraph copy did not exist yet. Now that the
                        # subgraph copy exists, we modify this name to its
                        # true value.
                        # Note: the alternative to this is to create the input
                        # logger after creating the subgraph, which is slightly
                        # more complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # Find the first node in the subgraph
                        cur_node = node_a_shadows_c
                        while cur_node.args[0] != input_logger:
                            cur_node = cur_node.args[0]  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(
                                    gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name], gm_b, logger_cls,
                        '_ns_logger_a_', node_a_shadows_c.name, name_a,
                        ref_name, NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:

                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                        node_b.name, name_b, ref_name,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c,
                    # or they may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
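# --- Example (illustrative) ---
# Hedged usage sketch for create_a_shadows_b(). It assumes `gm_fp32` and
# `gm_int8` are already-traced GraphModules for the fp32 and int8 versions of
# the same model, that a matcher such as get_matching_subgraph_pairs (from
# this numeric suite codebase) produced the subgraph pairs, and that
# `OutputLogger` is a logger nn.Module compatible with
# _insert_logger_after_node. These names are assumptions for illustration.
matched_pairs = get_matching_subgraph_pairs(gm_fp32, gm_int8)
gm_shadow = create_a_shadows_b(
    'fp32', gm_fp32, 'int8', gm_int8, matched_pairs,
    logger_cls=OutputLogger, should_log_inputs=False)
# After running calibration data through gm_shadow, each logger holds the
# intermediate activations of A and B for side-by-side comparison.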
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
    self.restore_state(model)
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    # run weight observers before inserting quant dequant nodes
    # for dynamic quantization
    if self.is_dynamic_quant:
        self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    matches = self._find_matches(model.graph, self.modules, self.patterns)
    quants = self._find_quants(model.graph, matches)
    self.quantized_graph = Graph()
    env = {}
    quant_env = {}

    def load_non_quantized(n):
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find node:' + n.name + \
                ' in quantized environment:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + n.name + \
                ' in float environment:' + str(env)
            assert n.name in quants, 'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """
        Input: quantized, which can be None, a boolean, a list or a tuple
        - if quantized is a list or tuple, then arg should be a list and
          the args with corresponding indexes will be quantized
        - if quantized is a boolean, then all args will be
          quantized/not quantized
        - if quantized is None, then we'll load the node as long as it
          exists

        Output: fn which takes arg_or_args, and loads them from the
        corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def is_quantized(node):
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, \
                'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but
            # quant_env will take precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            # materialize the results; all() would otherwise exhaust the
            # iterator before any() runs
            quantized = list(map(is_quantized, node))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")

    for node in model.graph.nodes:
        root_node, matched, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is node:
            result = obj.convert(self, node, load_arg)
            quantized = True
            # Need to get correct quantized/non-quantized state for the
            # output of CopyNode
            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = is_quantized(node.args[0])

            # output of dynamic quantization is not quantized
            if self.is_dynamic_quant:
                quantized = False

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            continue

        # handle activation post process calls
        if node.op == 'call_module':
            if node.target.split('.')[-1].startswith('activation_post_process_'):
                observer_module = self.modules[node.target]
                prev_node = node.args[0]
                if prev_node.name in quant_env:
                    # if the previous node is already quantized, we'll just
                    # remove the activation_post_process
                    quant_env[node.name] = quant_env[prev_node.name]
                    continue
                # replace activation post process with quantization ops
                root_module = self.modules['']
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)
                continue

        # dequantize inputs for the node that are not quantized
        env[node.name] = self.quantized_graph.node_copy(
            node, load_non_quantized)

    self.quantized_graph.output(
        map_arg(model.graph.result, load_non_quantized))

    to_be_removed = []
    for name, _ in model.named_modules():
        if name.split('.')[-1].startswith('activation_post_process_'):
            to_be_removed.append(name)
    for n in to_be_removed:
        delattr(model, n)
    model = GraphModule(model, self.quantized_graph)
    return model
def insert_observers_for_model(
    model: GraphModule,
    modules: Dict[str, torch.nn.Module],
    matches: Dict[str, MatchResult],
    qconfig_map: Dict[str, QConfigAny],
    graph: Graph,
    prepare_custom_config_dict: Dict[str, Any],
    input_quantized_idxs: List[int],
    output_quantized_idxs: List[int],
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and
         save it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
         dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to
    be complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.
    """
    node_name_to_target_dtype: Dict[str, Any] = {}
    cache_for_no_tensor_check: Dict[Node, bool] = dict()

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overriden
    # other nodes output dtype is specified by the qconfig
    modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        node_name_to_target_dtype[node.name] = \
            get_target_activation_dtype_for_node(
                node, qconfig, inputs_seen_counter, outputs_seen_counter,
                input_quantized_idxs, output_quantized_idxs, qhandler,
                modules, cache_for_no_tensor_check)

    # Second, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(
        model.graph, node_name_to_target_dtype, matches)

    # After this point, the current node and all of its arguments
    # have a dtype assigned. Now, we insert observers for inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    for node in nodes_before_observation:

        # check for matches
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))

        if node.op == 'placeholder':
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            # outside of the graph, and no additional observation is needed
            pass

        elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
            modules = dict(model.named_modules(remove_duplicate=False))
            this_node_dtype = node_name_to_target_dtype[node.name]
            output_not_a_tensor = this_node_dtype is None
            # TODO(future PR): consider stopping matching getitem
            is_getitem = node.op == 'call_function' and \
                node.target == operator.getitem

            skip_inserting_observers = (
                (qconfig is None) or
                output_not_a_tensor or
                is_getitem
            ) and (not node.op == 'output')

            if not skip_inserting_observers:
                if node.op != 'output':
                    # this modifies node inplace
                    maybe_insert_input_observers_for_node(
                        node, qconfig, model, modules, graph,
                        node_name_to_target_dtype,
                        qhandler, prepare_custom_config_dict)

                    is_last_node_of_pattern = root_node is node
                    is_like_copy_node = \
                        (qhandler is not None and (
                            isinstance(qhandler, CopyNodeQuantizeHandler)
                        ))
                    if is_last_node_of_pattern and (not is_like_copy_node):
                        # this returns the new observer node if it was needed
                        maybe_output_obs_node = maybe_insert_output_observer_for_node(
                            node, model, modules, graph, matches,
                            node_name_to_target_dtype, pattern, qhandler)
                        if maybe_output_obs_node is not None:
                            # Update users of original node to use the output
                            # observer instead. For example, change
                            #
                            #           next_node
                            #          /
                            # cur_node -> obs
                            #
                            # to
                            #
                            #                 next_node
                            #                 /
                            # cur_node -> obs
                            #
                            # We need to save orig users before updating uses
                            # because the list of users will change as we
                            # update uses
                            orig_users = list(node.users.keys())
                            for user_node in orig_users:
                                if user_node is maybe_output_obs_node:
                                    continue
                                user_node.replace_input_with(
                                    node, maybe_output_obs_node)

                            # for quantized cat nodes only, we modify the
                            # graph to make all inputs and outputs use the
                            # first input's observer
                            if isinstance(qhandler, CatQuantizeHandler):
                                adjust_observers_for_cat(node, model, modules)

                            if isinstance(qhandler, CustomModuleQuantizeHandler):
                                swap_custom_module_to_observed(
                                    node, qconfig, modules,
                                    prepare_custom_config_dict)

                else:  # output
                    maybe_insert_observers_before_graph_output(
                        node, output_quantized_idxs,
                        node_name_to_target_dtype, qconfig_map,
                        model, modules, graph)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        # increment the counters, so future inputs and outputs are assigned
        # correct dtypes
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1
            results_node = node

    return results_node
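# --- Example (illustrative) ---
# Hedged sketch: insert_observers_for_model is an internal step of FX graph
# mode quantization; users normally reach it through prepare_fx. This assumes
# the torch.quantization.quantize_fx API from the same era as the code above
# (later releases also require example_inputs).
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {'': get_default_qconfig('fbgemm')}
prepared = prepare_fx(model, qconfig_dict)
# the printed graph now contains activation_post_process_* call_module nodes
print(prepared.graph)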
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
    assert not inplace, 'inplace prepare is not supported yet'
    input_root = model.root
    if not inplace:
        input_root = copy.deepcopy(input_root)

    input_graph = model.graph
    self.is_dynamic_quant = is_dynamic_quant
    # TODO: allow user specified patterns
    if self.is_dynamic_quant:
        self.patterns = get_dynamic_quant_patterns()
    else:
        self.patterns = get_quant_patterns()

    propagate_qconfig_(input_root, qconfig_dict)
    if input_root.training:
        self._qat_swap_modules(input_root)

    self.modules = dict(input_root.named_modules())

    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(input_root, input_graph)

    # match the patterns that will get quantized
    matches = self._find_matches(input_graph, self.modules, self.patterns)

    # find _inputs_ to matched nodes that are not quantized; these
    # have to be quantized, which requires measuring stats, so
    # initialize a DefaultQuant object for each
    quants = self._find_quants(input_graph, matches)

    self.activation_post_process_map = dict()

    env = {}
    observed_graph = Graph()
    observed = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in input_graph.nodes:
        if node.name in observed:
            continue

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')
        root_node, _, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is None:
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)

            def insert_observer(node, observer):
                observer_name = get_new_observer_name(input_root)
                setattr(input_root, observer_name, observer)
                self.activation_post_process_map[node.name] = observer
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node),), {})
                observed.add(node.name)

            # don't need to insert observer for output in dynamic quantization
            if self.is_dynamic_quant:
                continue

            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                def is_observed(input_arg):
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and \
                    not obj.all_nodes:
                if node.args[0].name in observed:
                    observed.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                insert_observer(node, qconfig.activation())
        else:
            env[node.name] = observed_graph.node_copy(node, load_arg)

        if node.name not in observed and node.name in quants:
            observer_name = get_new_observer_name(input_root)
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                self.activation_post_process_map[node.name] = \
                    qconfig.weight() if is_weight else qconfig.activation()
                setattr(input_root, observer_name,
                        self.activation_post_process_map[node.name])
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node),), {})
                observed.add(node.name)

    observed_graph.output(load_arg(input_graph.result))

    observed = GraphModule(input_root, observed_graph)
    self.save_state(observed)
    return observed
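# --- Example (not part of the library code above) ---
# A minimal, illustrative reimplementation of the get_new_attr_name_with_prefix
# helper pattern used in _prepare: generate attribute names with a fixed
# prefix that do not collide with attributes already set on the module. The
# real helper lives elsewhere in this codebase; this sketch only shows the idea.
import torch

def get_new_attr_name_with_prefix(prefix):
    def get_new_attr_name(module):
        idx = 0
        while hasattr(module, prefix + str(idx)):
            idx += 1
        return prefix + str(idx)
    return get_new_attr_name

m = torch.nn.Module()
name_fn = get_new_attr_name_with_prefix('activation_post_process_')
n0 = name_fn(m)                        # 'activation_post_process_0'
setattr(m, n0, torch.nn.Identity())
n1 = name_fn(m)                        # 'activation_post_process_1'
assert n0 != n1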
def lower_to_elementwise_interpreter(orig_mod: torch.nn.Module) -> torch.nn.Module:
    # ===== Stage 1: Symbolic trace the module =====
    mod = symbolic_trace(orig_mod)

    # ===== Stage 2: Lower GraphModule representation to the C++
    #       interpreter's instruction format ======
    instructions = []
    constant_idx = 0
    constants = {}
    fn_input_names = []

    target_to_name = {operator.add: "add", operator.mul: "mul"}

    output_node: Optional[Node] = None
    # For each instruction, create a triple
    # (instruction_name : str, inputs : List[str], output : str)
    # to feed into the C++ interpreter
    for n in mod.graph.nodes:
        target, args, out_name = n.target, n.args, n.name
        assert len(n.kwargs) == 0, "kwargs currently not supported"

        if n.op == 'placeholder':
            # Placeholders specify function argument names. Save these
            # for later when we generate the wrapper GraphModule
            fn_input_names.append(target)
        elif n.op == 'call_function':
            assert target in target_to_name, "Unsupported call target " + target
            arg_names = []
            for arg in args:
                if not isinstance(arg, Node):
                    # Pull out constants. These constants will later be
                    # fed to the interpreter C++ object via add_constant()
                    arg_name = f'constant_{constant_idx}'
                    constants[arg_name] = torch.Tensor(
                        [arg] if isinstance(arg, numbers.Number) else arg)
                    arg_names.append(arg_name)
                    constant_idx += 1
                else:
                    arg_names.append(arg.name)
            instructions.append((target_to_name[target], arg_names, out_name))
        elif n.op == 'output':
            if output_node is not None:
                raise RuntimeError('Multiple output nodes!')
            output_node = n
        else:
            raise RuntimeError('Unsupported opcode ' + n.op)

    interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
    # Load constants
    for k, v in constants.items():
        interpreter.add_constant(k, v)
    # Specify names for positional input arguments
    interpreter.set_input_names(fn_input_names)
    # Load instructions
    interpreter.set_instructions(instructions)
    # Specify name for single output
    assert output_node is not None
    assert isinstance(output_node.args[0], torch.fx.Node)
    interpreter.set_output_name(output_node.args[0].name)

    # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
    class WrapperModule(torch.nn.Module):
        def __init__(self, interpreter):
            super().__init__()
            self.interpreter = interpreter

    wrapper = WrapperModule(interpreter)

    # Create a graph that: 1) Takes function arguments 2) Invokes the
    # interpreter 3) Returns the specified return value
    # FIXME: The following code could be greatly simplified by
    # symbolic_trace'ing the wrapper with a Tracer that considers the Wrapper
    # instance a root module, however, I can't get `__call__` exposed on
    # TorchBind classes without it messing up Python `hasattr` for some
    # reason. More digging into CPython's implementation of hasattr is
    # probably in order...
    graph = torch.fx.Graph()
    # Add placeholders for fn inputs
    placeholder_nodes = []
    for name in fn_input_names:
        placeholder_nodes.append(graph.create_node('placeholder', name))

    # Get the interpreter object
    interpreter_node = graph.create_node('get_attr', 'interpreter')

    # Add a node to call the interpreter instance
    output_node = graph.create_node(
        op='call_method', target='__call__',
        args=(interpreter_node, placeholder_nodes))

    # Register output
    graph.output(output_node)

    graph.lint(wrapper)

    # Return final GraphModule!!!
    return GraphModule(wrapper, graph)
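# --- Example (illustrative) ---
# Hedged usage sketch for lower_to_elementwise_interpreter(). It assumes the
# _TorchScriptTesting custom class library (a PyTorch test artifact) is
# registered in this build, so that
# torch.classes._TorchScriptTesting._ElementwiseInterpreter resolves.
import torch

class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return x * y + y

lowered = lower_to_elementwise_interpreter(AddMul())
x, y = torch.randn(3), torch.randn(3)
assert torch.allclose(lowered(x, y), x * y + y)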
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, and adds loggers to the
    inputs and/or outputs of each node in the instrumentation dicts.
    Returns a GraphModule with the new graph.
    """
    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif ((node in node_to_instrument_inputs_to_ref_node_name) or
                (node in node_to_instrument_outputs_to_ref_node_name)):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                if type(node.args[0]) == Node:
                    # create a single input logger
                    prev_node = env[node.args[0].name]
                    env[node.args[0].name] = _insert_logger_after_node(
                        prev_node, gm, logger_cls, '_ns_logger_', node.name,
                        model_name, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif type(node.args[0]) == \
                        torch.fx.immutable_collections.immutable_list:
                    # create N input loggers, one for each node
                    for arg_idx, arg in enumerate(node.args[0]):
                        prev_node = env[arg.name]
                        env[prev_node.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_',
                            node.name, model_name, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    raise AssertionError(
                        f"type {type(node.args[0])} is not handled yet")

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
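# --- Example (illustrative) ---
# Hedged usage sketch for remove_observers_add_loggers(). Assumes `gm` is an
# observed GraphModule and `OutputLogger` is a logger nn.Module with the
# constructor arguments used by _insert_logger_after_node; both names are
# assumptions for illustration. Here we instrument the outputs of every
# call_function node.
nodes_to_instrument_outputs = {
    node: node.name
    for node in gm.graph.nodes if node.op == 'call_function'
}
gm_logged = remove_observers_add_loggers(
    gm,
    node_to_instrument_inputs_to_ref_node_name={},
    node_to_instrument_outputs_to_ref_node_name=nodes_to_instrument_outputs,
    logger_cls=OutputLogger,
    model_name='model_a')
# running calibration data through gm_logged records each instrumented output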