def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.

    For every ``call_module`` node of ``model`` whose module type appears as a
    value of ``DEFAULT_OP_LIST_TO_FUSER_METHOD``, the qconfig registered in
    ``qconfig_dict["object_type"]`` for the first op of the matching op tuple
    is also registered under the fused module type.

    Args:
        model: the GraphModule whose ``call_module`` nodes are inspected.
        qconfig_dict: qconfig dictionary; its ``"object_type"`` entry, when
            present, is updated in place.

    Returns:
        The same ``qconfig_dict`` object that was passed in (returned as-is
        when it has no ``"object_type"`` entry).

    Raises:
        LookupError: if the ops that make up a fused module do not all have
            equal qconfigs in the ``"object_type"`` entry.
    """
    object_type_dict = qconfig_dict.get("object_type", None)
    if object_type_dict is None:
        return qconfig_dict

    modules = dict(model.named_modules())
    # Loop-invariant: collect the fuser targets once instead of rebuilding a
    # list of them for every node in the graph.
    fuser_methods = tuple(DEFAULT_OP_LIST_TO_FUSER_METHOD.values())

    for node in model.graph.nodes:
        if node.op != 'call_module':
            continue
        module_type = type(modules[str(node.target)])
        if module_type not in fuser_methods:
            continue

        for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
            if module_type != fuser:
                continue
            fused_qconfig = object_type_dict.get(ops[0], None)

            # Raise an error if the modules in the fused module have
            # different qconfigs specified in the qconfig_dict.
            # ops[0] is the reference qconfig, so only the remaining ops
            # need to be compared against it.
            for op in ops[1:]:
                if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
                    raise LookupError("During fusion, we need to specify the same " +
                                      f"qconfigs for both modules in {module_type}.")

            if fused_qconfig is not None:
                object_type_dict[module_type] = fused_qconfig

    return qconfig_dict
def get_matching_activations_a_shadows_b(
    gm_a_shadows_b: GraphModule,
    logger_cls: Callable,
) -> Dict[str, Dict[str, List[torch.Tensor]]]:
    """
    Same thing as get_matching_activations, but for an `a_shadows_b` model.

    Every submodule of ``gm_a_shadows_b`` that is a logger (an instance of
    ``logger_cls``, or a scripted module whose ``original_name`` is
    ``'OutputLogger'``) contributes its ``stats`` to the result, stored at
    ``result['<node name>.stats'][logger.model_name]``.

    A logger with a populated ``other_node_name`` was attached to a node from
    model A, and ``other_node_name`` is the corresponding model B node name.
    Such a logger is filed under that name, so both models' stats for a pair
    share one key. Loggers without it are filed under their own
    ``node_name``.

    TODO(future PR): real docblock
    """
    by_node: Dict[str, Dict[str, List[torch.Tensor]]] = \
        collections.defaultdict(dict)

    for _, submodule in gm_a_shadows_b.named_modules():
        # TODO(future PR): better check when scripted
        if not (
            isinstance(submodule, logger_cls)  # type: ignore
            or (
                isinstance(submodule, torch.jit.RecursiveScriptModule)
                and submodule.original_name == 'OutputLogger'
            )
        ):
            continue

        partner_name = submodule.other_node_name
        base_name = submodule.node_name if partner_name is None else partner_name
        by_node[base_name + '.stats'][submodule.model_name] = submodule.stats

    return dict(by_node)
def _find_matches(
        self, root: GraphModule, graph: Graph,
        patterns: Dict[Pattern, Callable]
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
    """Match fusion ``patterns`` against the nodes of ``graph``.

    Nodes are visited in reverse graph order. For each node that is not yet
    part of a match, the first pattern in ``patterns`` for which ``is_match``
    succeeds is applied, and the remaining patterns are skipped for that node.

    Args:
        root: module whose ``named_modules()`` are passed to ``is_match``.
        graph: the graph to match against.
        patterns: mapping from pattern to a callable invoked as
            ``value(self, root_node)`` to build the handler for a match.

    Returns:
        Dict from node name to ``(root_node, pattern, matched_node_pattern,
        handler)`` for every node covered by a match.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, pattern, matched_node_pattern, handler)
    match_map: Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]] = {}

    def apply_match(pattern, node, match, matched_node_pattern):
        # Walk ``pattern`` recursively alongside ``node`` and its positional
        # args, recording every leaf node that is not yet claimed.
        if isinstance(pattern, tuple):
            s, *args = pattern
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern)
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match, current_node_pattern)
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern matches will take precedence
            if node.name not in match_map:
                matched_node_pattern.append(node)
                root_node, pattern, handler = match
                match_map[node.name] = (root_node, pattern, matched_node_pattern, handler)

    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, value in patterns.items():
                matched_node_pattern: List[Node] = []
                if is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, pattern, value(self, node)), matched_node_pattern)
                    # Only the first matching pattern may be applied. Without
                    # this break, a later matching pattern could still claim
                    # this root's (unclaimed) input nodes with a different
                    # pattern/handler, even though the root belongs to the
                    # first match.
                    break

    return match_map
def _find_matches(
        self, root: GraphModule, graph: Graph,
        patterns: Dict[Pattern, Callable]
) -> Dict[str, Tuple[Node, FuseHandler]]:
    """Match fusion ``patterns`` against the nodes of ``graph``.

    Nodes are visited in reverse graph order. For each node that is not yet
    part of a match, the first pattern in ``patterns`` for which ``is_match``
    succeeds is applied, and the remaining patterns are skipped for that node.

    Args:
        root: module whose ``named_modules()`` are passed to ``is_match``.
        graph: the graph to match against.
        patterns: mapping from pattern to a callable invoked as
            ``value(self, root_node)`` to build the handler for a match.

    Returns:
        Dict from node name to ``(root_node, handler)`` for every node
        covered by a match.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, handler)
    match_map: Dict[str, Tuple[Node, FuseHandler]] = {}

    def apply_match(pattern, node, match):
        # Walk ``pattern`` recursively alongside ``node`` and its positional
        # args, assigning ``match`` to every leaf node not yet claimed.
        if isinstance(pattern, tuple):
            s, *args = pattern
            apply_match(s, node, match)
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match)
        else:
            # the first pattern matches will take precedence
            if node.name not in match_map:
                match_map[node.name] = match

    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, value in patterns.items():
                if is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, value(self, node)))
                    # Only the first matching pattern may be applied;
                    # otherwise a later pattern could claim this root's
                    # unclaimed inputs with a different handler.
                    break

    return match_map
def _convert_equalization_ref(model: GraphModule):
    """
    Reference function which applies changes needed for equalization, but
    does not quantize the nodes.

    The equalization helpers receive a name -> module map of ``model`` that
    keeps duplicate names. The result is a fresh ``GraphModule`` built from
    ``model`` and its graph.
    """
    # remove_duplicate=False keeps every registered name, even when one
    # module instance is reachable under several names.
    named_modules = dict(model.named_modules(remove_duplicate=False))

    # Calculate the equalization scale, update the observers with the scaled
    # inputs, and scale the weight
    weight_eq_observers = update_obs_for_equalization(model, named_modules)
    convert_eq_obs(model, named_modules, weight_eq_observers)

    return GraphModule(model, model.graph)
def add_activation_info_to_dict(
    model_name: str,
    model: GraphModule,
    results: Dict[str, Dict[str, List[torch.Tensor]]],
    logger_cls: Callable,
) -> None:
    """
    Copy the ``stats`` of every logger in ``model`` into ``results``.

    A submodule counts as a logger when it is an instance of ``logger_cls``,
    or a scripted module whose ``original_name`` is ``'OutputLogger'``. Each
    logger's stats are stored at
    ``results[logger.ref_node_name + '.stats'][model_name]``, and the inner
    dict is created when missing. ``results`` is mutated in place.
    """
    for _, submodule in model.named_modules():
        # TODO(future PR): better check when scripted
        if not (
            isinstance(submodule, logger_cls)  # type: ignore
            or (
                isinstance(submodule, torch.jit.RecursiveScriptModule)
                and submodule.original_name == 'OutputLogger'
            )
        ):
            continue

        per_model = results.setdefault(submodule.ref_node_name + '.stats', {})
        per_model[model_name] = submodule.stats
def _find_matches( root: GraphModule, graph: Graph, patterns: Dict[Pattern, Callable] ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]: modules = dict(root.named_modules()) # node name -> (root_node, match_value) match_map: Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {} # a map from node to the matched subpattern node_to_subpattern: Dict[Node, Any] = {} # TODO: dedup with quantization matching function in match_utils.py def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern): if isinstance(pattern, tuple): s, *args = pattern current_node_pattern: List[Node] = [] apply_match(s, node, match, current_node_pattern, node_to_subpattern) for subpattern, arg in zip(args, node.args): apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern) matched_node_pattern.append(tuple(current_node_pattern)) else: # the first pattern matches will take precedence if node.name not in match_map: matched_node_pattern.append(node) # MatchAllNode here is actually MatchAllInputNode which should not # be added to match_map if pattern is not MatchAllNode: node_to_subpattern[node] = pattern root_node, pattern, handler = match match_map[node.name] = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern) for node in reversed(graph.nodes): if node.name not in match_map: for pattern, value in patterns.items(): matched_node_pattern: List[Node] = [] if is_match(modules, node, pattern): apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern) break return match_map
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Rebuild the graph of ``gm`` without observers and with loggers added.

    Every activation post process ``call_module`` node is dropped, and its
    users are wired to the observer's input. Every node that is a key of
    ``node_to_instrument_to_ref_node_name`` is copied and then followed by a
    logger created by ``_insert_logger_after_node``. The mapped value
    (possibly ``None``) is passed to that helper as the reference node name.

    Returns:
        A new GraphModule built from ``gm`` and the rebuilt graph.
    """
    rebuilt = Graph()
    copied: Dict[str, Any] = {}
    submodules = dict(gm.named_modules())

    def remap(arg):
        return map_arg(arg, lambda n: copied[n.name])

    for old_node in gm.graph.nodes:
        if old_node.op == 'output':
            rebuilt.output(map_arg(old_node.args[0], remap))
            continue

        is_observer = (
            old_node.op == 'call_module'
            and is_activation_post_process(submodules[old_node.target])
        )
        if is_observer:
            # remove activation post process node: consumers read its input
            copied[old_node.name] = copied[old_node.args[0].name]
        elif old_node in node_to_instrument_to_ref_node_name:
            ref_node_name = node_to_instrument_to_ref_node_name[old_node]
            # copy the base node first, then place the logger after it
            base_copy = rebuilt.node_copy(old_node, remap)
            copied[old_node.name] = base_copy
            copied[old_node.name] = _insert_logger_after_node(
                base_copy, gm, logger_cls, '_ns_logger_', model_name,
                ref_node_name)
        else:
            copied[old_node.name] = rebuilt.node_copy(old_node, remap)

    return GraphModule(gm, rebuilt)
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.

    For each ``call_module`` node whose module is a ``_FusedModule``, the
    qconfig registered in ``qconfig_dict["object_type"]`` for the type of the
    fused module's first child is also registered under the fused module's
    own type.

    Returns:
        ``qconfig_dict`` itself; its ``"object_type"`` entry is updated in
        place. It is returned untouched when that entry is missing.

    Raises:
        LookupError: if another child type of a fused module has a qconfig
            that is not equal to the first child's.
    """
    type_to_qconfig = qconfig_dict.get("object_type", None)
    if type_to_qconfig is None:
        return qconfig_dict

    named = dict(model.named_modules())

    for node in model.graph.nodes:
        if node.op != 'call_module' or node.target not in named:
            continue
        candidate = named[str(node.target)]
        if not isinstance(candidate, _FusedModule):
            continue

        children = list(candidate._modules.values())
        head_qconfig = type_to_qconfig.get(type(children[0]), None)

        # Raise an error if the modules in the fused module have
        # different qconfigs specified in the qconfig_dict
        # TODO: currently it only works for modules,
        # need to make this work for torch.nn.functional.relu
        # TODO: currently it only works for object_type configurations,
        # ideally it should work for different types of configurations,
        # maybe we want to redesign this part
        for child in children[1:]:
            if qconfig_equals(type_to_qconfig.get(type(child), None), head_qconfig):
                continue
            raise LookupError(
                "During fusion, we need to specify the same " +
                f"qconfigs for all module types in {type(candidate)} " +
                f"offending type: {type(child)}")

        if head_qconfig is not None:
            type_to_qconfig[type(candidate)] = head_qconfig

    return qconfig_dict
def add_activation_info_to_dict(
    model: GraphModule,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    """
    Record the output of every logger in ``model`` into ``results``.

    A submodule counts as a logger when it is an instance of ``logger_cls``,
    or a scripted module whose ``original_name`` is ``'OutputLogger'``. Each
    logger produces one entry at ``results[mod.ref_name][mod.model_name]``
    holding its type tag, ``stats``, ``node_name`` and ``node_target_type``.
    The ``results[mod.ref_name]`` dict is created when missing. ``results``
    is mutated in place.
    """
    for _, submodule in model.named_modules():
        # TODO(future PR): better check when scripted
        if not (
            isinstance(submodule, logger_cls)  # type: ignore
            or (
                isinstance(submodule, torch.jit.RecursiveScriptModule)
                and submodule.original_name == 'OutputLogger'
            )
        ):
            continue

        per_model = results.setdefault(submodule.ref_name, {})
        per_model[submodule.model_name] = {
            'type': NSSingleResultValuesType.NODE_OUTPUT.value,
            'values': submodule.stats,
            'node_name': submodule.node_name,
            'node_target_type': submodule.node_target_type,
        }
def get_matching_activations_a_shadows_b(
    gm_a_shadows_b: GraphModule,
    logger_cls: Callable,
) -> NSResultsType:
    """
    Same thing as get_matching_activations, but for an `a_shadows_b` model.

    Each logger submodule (an instance of ``logger_cls``, or a scripted
    module whose ``original_name`` is ``'OutputLogger'``) contributes one
    entry at ``result[mod.ref_name][mod.model_name]`` with its type tag,
    ``stats``, ``node_name`` and ``node_target_type``. A plain ``dict`` is
    returned.

    TODO(future PR): real docblock
    """
    collected: NSResultsType = collections.defaultdict(dict)

    for _, submodule in gm_a_shadows_b.named_modules():
        # TODO(future PR): better check when scripted
        if not (
            isinstance(submodule, logger_cls)  # type: ignore
            or (
                isinstance(submodule, torch.jit.RecursiveScriptModule)
                and submodule.original_name == 'OutputLogger'
            )
        ):
            continue

        collected[submodule.ref_name][submodule.model_name] = {
            'type': NSSingleResultValuesType.NODE_OUTPUT.value,
            'values': submodule.stats,
            'node_name': submodule.node_name,
            'node_target_type': submodule.node_target_type,
        }

    return dict(collected)
def _convert(self, model: GraphModule, debug: bool = False,
             convert_custom_config_dict: Optional[Dict[str, Any]] = None,
             is_standalone_module: bool = False) -> GraphModule:
    """Convert an observed ``model`` into a quantized GraphModule.

    A standalone module is a submodule that is not inlined in its parent
    module and is quantized separately as one unit.

    Placeholders and graph outputs are float by default. The indices listed
    under ``"input_quantized_idxs"`` / ``"output_quantized_idxs"`` in
    ``self.prepare_custom_config_dict`` are kept quantized instead.

    Args:
        model: the observed GraphModule. It is switched to eval mode and
            moved to CPU in place.
        debug: forwarded to each matched handler's ``convert``.
        convert_custom_config_dict: custom convert configuration; ``None``
            is treated as ``{}``.
        is_standalone_module: not read by this method.

    Returns:
        A new GraphModule built from the converted graph, with activation
        post process nodes removed.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    self.restore_state(model)
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    assert self.patterns is not None
    matches = self._find_matches(
        model.graph, self.modules, self.patterns,
        custom_module_classes=custom_module_classes)

    quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
        self._find_quants(model.graph, matches)

    self.quantized_graph = Graph()
    # node name -> converted node holding a float value
    env: Dict[str, Node] = {}
    # node name -> converted node holding a quantized value
    quant_env: Dict[str, Node] = {}

    def load_non_quantized(n: Node) -> Node:
        # Float view of ``n``; a quantized-only value gets a dequantize
        # node, which is cached in ``env``.
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find ' + \
                'node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + \
                str(env) + ' quant_env:' + str(quant_env)
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n: Node) -> Node:
        assert n.name in quant_env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in quant environment:' + str(quant_env)
        return quant_env[n.name]

    def load_x(n: Node) -> Node:
        # Whatever form exists, preferring the quantized one.
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
                 ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is None, then we'll load the node as long as it
            exists

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env or node_arg.name in quant_env, \
                'Expecting node_arg to be in the environment'
            # there might be nodes appearing in both environments, but
            # quant_env will take precedence
            return node_arg.name in quant_env
        elif isinstance(node_arg, list):
            # Materialize the results. A bare ``map`` iterator would be
            # partially consumed by ``all``, and ``any`` would then only see
            # the leftover elements, so mixed lists could be misreported as
            # unquantized instead of raising below.
            quantized = [node_arg_is_quantized(a) for a in node_arg]
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: Any) -> bool:
        """ Check if output node is quantized or not """
        assert self.modules is not None
        # by default the output is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of CopyNode
        if type(obj) in [
                CopyNode,
                FixedQParamsOpQuantizeHandler
        ]:
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'CopyNode of type ' + node.op + ' is not handled'
            quantized = node_arg_is_quantized(node.args[0])

        if not activation_is_statically_quantized(qconfig) or \
           not input_output_observed(obj):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert self.modules is not None
        assert isinstance(node.target, str)
        observer_module = self.modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float16:
            # activations are not quantized for
            # fp16 dynamic quantization
            # copy the activaiton_post_process node here
            # since we may need it when we insert prepack
            # op for weight of linear, this will be removed
            # later in a separate pass
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)
        elif isinstance(prev_node, Node) and prev_node.name in quant_env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            quant_env[node.name] = quant_env[prev_node.name]
        else:
            # replace activation post process with quantization ops
            root_module = self.modules[""]
            assert isinstance(node.args[0], Node)
            quant_env[node.name] = quantize_node(
                root_module, self.quantized_graph,
                load_non_quantized(node.args[0]), observer_module)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == 'output':
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            self.quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            if qconfig is None:
                result = self.quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                is_standalone_module_node = (
                    node.op == 'call_module' and
                    is_observed_standalone_module(
                        self.modules[node.target])  # type: ignore
                )
                result = obj.convert(
                    self, node, load_arg, debug=debug,
                    convert_custom_config_dict=convert_custom_config_dict)
                if is_standalone_module_node:
                    quantized = False
                else:
                    quantized = is_output_quantized(node, obj, qconfig)

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            # non-root node of a match: handled by the root's convert
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(self.modules[node.target]):
            insert_quantize_node(node)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                quant_env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            env[node.name] = \
                self.quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    env = {}

    def load_arg_simple(a: Argument) -> Argument:
        return map_arg(a, lambda node: env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_simple))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(self.modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]
        else:
            env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_simple)

    # removes qconfig and activation_post_process modules
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
    should_log_inputs: bool,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs (match name -> ((start_a, end_a), (start_b, end_b))):
    {'op0': ((op0_fp32, op0_fp32), (op0_int8, op0_int8)),
     'op1': ((op1_fp32, op1_fp32), (op1_int8, op1_int8))}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    * if should_log_inputs, also adds loggers to the inputs of both

    Activation post process nodes of gm_b are dropped from graph C. Only
    subgraphs of B consisting of a single node are supported.
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a_and_name[node_end_b] = \
            ((node_start_a, node_end_a), match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a_and_name:
            (node_start_a, node_end_a), ref_name = \
                node_b_to_matched_subgraph_a_and_name[node_b]

            # if necessary, log the input of node_c
            if should_log_inputs:
                if isinstance(node_b.args[0], Node):
                    prev_node_c = env_c[node_b.args[0].name]
                    env_c[prev_node_c.name] = _insert_logger_after_node(
                        prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                        node_b.name, name_b, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif isinstance(node_b.args[0], list):
                    # first, save the prev_node instances, because they
                    # will be overwritten in the env after the first logger
                    # is added
                    prev_node_c_list = [
                        env_c[arg.name] for arg in node_b.args[0]
                    ]

                    for arg_idx, arg in enumerate(node_b.args[0]):
                        prev_node_c = prev_node_c_list[arg_idx]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    # logging of inputs which are not lists is not supported yet
                    raise AssertionError(
                        f"type {type(node_b.args[0])} is not handled yet")
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)?

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            # note: not inserting to env_c because all nodes which use the dtype
            #   casts are copied from graph_a
            #
            # subgraph so far:
            #
            #           (dtype_cast_node)+
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # if input logging is enabled, log the input to the subgraph
            if should_log_inputs:
                # The subgraph copy does not exist yet, so the input loggers
                # get an empty ref_node_name; it is filled in below once the
                # copy has been created.
                ref_node_name = ''
                if isinstance(dtype_cast_node, Node):
                    dtype_cast_node = _insert_logger_after_node(
                        dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                        ref_node_name, name_a, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                    input_logger: Union[Node, List[Node]] = dtype_cast_node
                else:
                    assert isinstance(dtype_cast_node, list)
                    new_loggers = []
                    for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                            dtype_cast_node):
                        dtype_cast_logger = _insert_logger_after_node(
                            dtype_cast_node_inner, gm_b, logger_cls,
                            '_ns_logger_a_inp_', ref_node_name, name_a,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=dtype_cast_idx)
                        new_loggers.append(dtype_cast_logger)
                    dtype_cast_node = new_loggers
                    input_logger = dtype_cast_node
                # subgraph so far:
                #
                #       (dtype_cast_node)+ -> (logger_a_input)?
                #                  /
                # prev_node_c -> (logger_c_input)? -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                dtype_cast_node, node_start_a, node_end_a, gm_a, gm_b,
                node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            if should_log_inputs:
                # When we created the input logger, we left the ref_node_name
                # as an empty string, because the subgraph copy did not exist
                # yet. Now that the subgraph copy exists, we modify this name
                # to its true value.
                # Note: the alternative to this is to create the input logger
                # after creating the subgraph, which is slightly more
                # complicated. This is the lesser of two evils.
                # input_logger = env_c[dtype_cast_node.name]
                # Find the first node in the subgraph
                cur_node = node_a_shadows_c
                while cur_node.args[0] != input_logger:
                    cur_node = cur_node.args[0]  # type: ignore
                if isinstance(input_logger, Node):
                    input_logger_mod = getattr(gm_b, input_logger.name)
                    input_logger_mod.ref_node_name = cur_node.name
                else:
                    assert isinstance(input_logger, list)
                    for input_logger_inner in input_logger:
                        input_logger_mod = getattr(gm_b, input_logger_inner.name)
                        input_logger_mod.ref_node_name = cur_node.name

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                node_b.name, name_b, ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy
            #                  /
            # (prev_node_c+) -> (logger_c_input)? -> node_c -> logger_c

            # hook up a logger to the mod_a copy
            # Note: both output loggers receive the same ref_name (this
            # pair's key in matched_subgraph_pairs), so their results can be
            # paired up later.
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                node_a_shadows_c.name, name_a, ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c -> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
def convert(model: GraphModule, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None, is_standalone_module: bool = False, _remove_qconfig_flag: bool = True) -> QuantizedGraphModule: """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. Returns a quantized standalone module, whether input/output is quantized is specified by prepare_custom_config_dict, with input_quantized_idxs, output_quantized_idxs, please see docs for prepare_fx for details """ if convert_custom_config_dict is None: convert_custom_config_dict = {} patterns, node_name_to_scope, prepare_custom_config_dict = restore_state( model) qconfig_map: Dict[ str, QConfigAny] = model._qconfig_map # type: ignore[assignment] # always run weight observers in the top level forward method # for dynamic quant ops or weight only quant ops run_weight_observers(model) # move to cpu since we only have quantized cpu kernels model.eval().cpu() # mapping from fully qualified module name to module instance # for example, # { # '': Model(...), # 'linear': Linear(...), # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), # } # We use remove_duplicate=False here because torch.cat uses # the same activation_post_process module instance but different names modules = dict(model.named_modules(remove_duplicate=False)) custom_module_classes = get_custom_module_class_keys( convert_custom_config_dict, "observed_to_quantized_custom_module_class") matches = find_matches(model.graph, modules, patterns, qconfig_map, custom_module_classes=custom_module_classes) quantized_graph = Graph() env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {} graph_inputs: List[str] = [] for node in model.graph.nodes: if node.op == 'placeholder': graph_inputs.append(node.name) def load_non_quantized(n: Node) -> Node: assert n.name in env, \ 'trying to load float node but did not find ' + \ 'node:' + n.name + \ ' in env: ' + \ str(env) quantized_node, 
dtype = env[n.name] if dtype and dtype != torch.float: env[n.name] = Proxy(quantized_node).dequantize().node, torch.float return env[n.name][0] def load_quantized(n: Node) -> Node: assert n.name in env, \ 'trying to load quantized node but did not find node:' + \ n.name + ' in environment:' + str(env) quantized_node, dtype = env[n.name] assert dtype in [torch.quint8, torch.qint8, torch.float16], \ f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}' return quantized_node def load_x(n: Node) -> Node: assert n.name in env, \ 'node ' + n.name + ' does not exist in environment' return env[n.name][0] def load_arg( quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, list, boolean or tuple - if quantized is None, then we'll load the node as long as it exists - if quantized is a boolean, then all args will be quantized/not quantized - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False) - if quantized is a list or tuple, then arg should be a list and the args with corresponding indexes will be quantized Output: fn which takes arg_or_args, and loads them from the corresponding environment depending on the value of quantized. 
""" assert quantized is None or \ isinstance(quantized, (tuple, list, bool)), type(quantized) if isinstance(quantized, (tuple, list)) and len(quantized) == 0: # empty tuple or list means nothing is quantized quantized = False def load_arg_impl(arg_or_args): # we'll update the format of `quantized` # to better match arg_or_args updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized if isinstance(quantized, (tuple, list)) and \ len(quantized) == 1 and isinstance(arg_or_args, Node): # when argument is one Node instead of tuple, we just need to check # 0 is in the quantized list updated_quantized = 0 in quantized if updated_quantized is None: return map_arg(arg_or_args, load_x) if isinstance(updated_quantized, bool): return map_arg( arg_or_args, load_quantized if updated_quantized else load_non_quantized) elif isinstance(updated_quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): if i in updated_quantized: loaded_args.append(map_arg(a, load_quantized)) else: loaded_args.append(map_arg(a, load_non_quantized)) return type(arg_or_args)(loaded_args) return load_arg_impl def node_arg_is_quantized(node_arg: Any) -> bool: if isinstance(node_arg, Node): assert node_arg.name in env, \ 'Expecting node_arg to be in the environment' if node_arg.name in env: _, dtype = env[node_arg.name] return dtype != torch.float else: return False elif isinstance(node_arg, list): quantized = map(node_arg_is_quantized, node_arg) if all(quantized): return True elif not any(quantized): return False else: raise Exception( "partially quantized inputs in list not handled yet") else: return False def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool: """ Check if output node is quantized or not """ assert modules is not None # by default the output for a 
quantizable node is expected to be quantized quantized = True # Need to get correct quantized/non-quantized state forn the output # of FixedQParamsQuantizeHandler # TODO: we may want to try to remove the special case here # as well if obj.should_mark_output_quantized_from_input_quantized_status( qconfig): assert node.op in [ 'call_module', 'call_function', 'call_method'], \ 'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled' # TODO: need to extend this to consider all relevant args instead of just arg[0] quantized = node_arg_is_quantized(node.args[0]) # the output is unquantized if the node is not a CopyNode # or the activation is not statically quantized if not activation_is_statically_quantized(qconfig) or \ not obj.input_output_observed(): quantized = False if node_return_type_is_int(node): quantized = False return quantized def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> None: """ Given a activation_post_process module call node, insert a quantize node""" assert modules is not None assert isinstance(node.target, str) observer_module = modules[node.target] prev_node = node.args[0] if observer_module.dtype == torch.float32: # copy the observer for fp32 dtype env[node.name] = quantized_graph.node_copy( node, load_non_quantized), torch.float elif isinstance(prev_node, Node) and prev_node.name in env: # if previous node is already quantized, we'll just remove the # activation_post_process _, prev_dtype = env[prev_node.name] current_dtype = observer_module.dtype if prev_dtype == current_dtype: env[node.name] = env[prev_node.name] else: root_module = modules[""] assert isinstance(prev_node, Node) observer_dtype: torch.dtype = observer_module.dtype # type: ignore[assignment] env[node.name] = (quantize_node(load_non_quantized(prev_node), observer_module, node, modules, quantized_graph, node_name_to_scope, is_input=True), observer_dtype) else: # replace activation post process with quantization ops root_module = modules[""] 
assert isinstance(node.args[0], Node) dtype: torch.dtype = observer_module.dtype # type: ignore[assignment] env[node.name] = (quantize_node(load_non_quantized(node.args[0]), observer_module, node, modules, quantized_graph, node_name_to_scope, is_input=True), dtype) # additional state to override inputs to be quantized, if specified # by the user placeholder_node_seen_cnt = 0 output_node_seen_cnt = 0 input_quantized_idxs: List[int] = prepare_custom_config_dict.get( "input_quantized_idxs", []) output_quantized_idxs: List[int] = prepare_custom_config_dict.get( "output_quantized_idxs", []) for node in model.graph.nodes: if node.op == "output": cur_output_node_idx = output_node_seen_cnt output_node_seen_cnt += 1 if cur_output_node_idx in output_quantized_idxs: # Result are kept quantized if the user specified the # output_quantized_idxs override. graph_output = map_arg(node.args[0], load_x) else: graph_output = map_arg(node.args[0], load_non_quantized) quantized_graph.output(graph_output) continue root_node, matched, matched_pattern, obj, qconfig = \ matches.get(node.name, (None, None, None, None, None)) if root_node is node: is_observed_standalone_module_node = ( node.op == 'call_module' and is_observed_standalone_module(modules[node.target])) if qconfig is None and not is_observed_standalone_module_node: result = quantized_graph.node_copy(node, load_non_quantized) quantized = False else: assert obj is not None # We will get whether the output is quantized or not before # convert for standalone module and after convert # for non-standalone module, since _standalone_module_output_quantized_idxs # is only available in observed standalone module if is_observed_standalone_module_node: out_quant_idxs = modules[ node. 
target]._standalone_module_output_quantized_idxs.tolist( ) # type: ignore[operator] # noqa: B950 assert len( out_quant_idxs ) <= 1, "Currently standalone only support one output" quantized = 0 in out_quant_idxs qconfig = qconfig_map[node.name] result = obj.convert( node, qconfig, modules, quantized_graph, node_name_to_scope, load_arg, is_reference=is_reference, convert_custom_config_dict=convert_custom_config_dict) if not is_observed_standalone_module_node: quantized = is_output_quantized(node, obj, qconfig, modules) if quantized: env[node.name] = result, activation_dtype(qconfig) else: env[node.name] = result, torch.float continue elif root_node is not None: if qconfig is None: # This branch is hit if all of these conditions are met: # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu) # 2. the current node is not the "root_node" of the pattern # 3. quantization for this pattern is disabled # # In this case, we need to make sure to populate the env with # intermediate nodes manually, because the QuantizeHandler.convert # function will not be called. 
result = quantized_graph.node_copy(node, load_non_quantized) env[node.name] = result, torch.float continue # handle activation post process calls if node.op == 'call_module' and \ is_activation_post_process(modules[node.target]): insert_quantize_node(node, modules) elif node.op == 'placeholder': cur_placeholder_node_idx = placeholder_node_seen_cnt placeholder_node_seen_cnt += 1 if cur_placeholder_node_idx in input_quantized_idxs: env[node.name] = \ quantized_graph.node_copy( node, load_non_quantized), torch.quint8 else: env[node.name] = \ quantized_graph.node_copy(node, load_non_quantized), torch.float else: # copy quantized or non-quantized node # get_tensor_info_node like shape works for both # quantized and non-quantized input and output a non-Tensor # (we use None for dtype currently for non-Tensors) if is_get_tensor_info_node(node): env[node.name] = \ quantized_graph.node_copy(node, load_x), None else: env[node.name] = \ quantized_graph.node_copy(node, load_non_quantized), torch.float # remove activation post process act_post_process_removed_graph = Graph() remove_env: Dict[str, Node] = {} def load_arg_remove(a: Argument) -> Argument: return map_arg(a, lambda node: remove_env[node.name]) for node in quantized_graph.nodes: if node.op == 'output': act_post_process_removed_graph.output( map_arg(node.args[0], load_arg_remove)) continue if node.op == 'call_module' and \ is_activation_post_process(modules[node.target]): # remove activation post process node remove_env[node.name] = remove_env[node.args[0].name] else: remove_env[node.name] = act_post_process_removed_graph.node_copy( node, load_arg_remove) # removes qconfig and activation_post_process modules if _remove_qconfig_flag: _remove_qconfig(model) preserved_attributes = set( convert_custom_config_dict.get("preserved_attributes", [])) model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes) if not is_reference: model = fold_weight(model, node_name_to_scope) return model
def _prepare(self, model: GraphModule, qconfig_dict: Any,
             prepare_custom_config_dict: Optional[Dict[str, Any]],
             is_standalone_module: bool) -> GraphModule:
    """ Insert observers into ``model`` and return the observed module.

    standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    both input and output are observed in prepared standalone module

    Graph inputs whose position among the placeholder nodes (in graph
    order) is listed in ``prepare_custom_config_dict["input_quantized_idxs"]``
    are marked as observed and skipped by input-observer insertion. For
    graph outputs whose position is listed in ``"output_quantized_idxs"``,
    an observer is inserted on the node feeding the output unless that node
    is already observed.

    Args:
        model: traced model to prepare; qconfigs are propagated onto it and,
            when ``model.training`` is set, QAT modules are swapped in.
        qconfig_dict: qconfig dict used to propagate qconfigs and to build
            the node-name -> qconfig map.
        prepare_custom_config_dict: optional custom config (``None`` is
            treated as ``{}``); stored on ``self.prepare_custom_config_dict``.
        is_standalone_module: not read by this method.

    Returns:
        model(GraphModule): the observed model, after ``save_state`` and
        ``mark_observed_module``.
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    self.prepare_custom_config_dict = prepare_custom_config_dict

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    self.patterns = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        self._qat_swap_modules(model, additional_qat_module_mapping)

    # collected after the optional QAT swap so it reflects swapped modules
    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    standalone_module_classes = prepare_custom_config_dict.get(
        "standalone_module_class", None)
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    assert self.patterns is not None
    matches = self._find_matches(
        model.graph, self.modules, self.patterns, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize an DefaultQuantizeHandler object for each
    quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
        self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    env: Dict[Any, Any] = {}
    observed_graph = Graph()
    observed_node_names_set: Set[str] = set()

    def load_arg(a):
        # look up the observed-graph counterpart of every Node inside `a`
        return map_arg(a, lambda node: env[node.name])

    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == 'output':
            # If this output is hardcoded to be quantized, insert an
            # observer on the previous node if it does not already
            # exist.
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                prev_node = node.args[0]
                assert isinstance(prev_node, Node), \
                    ('hardcoding list/dict outputs to be quantized is ' +
                     'not supported')
                if prev_node.name not in observed_node_names_set:
                    assert self.qconfig_map is not None
                    local_qconfig = self.qconfig_map[prev_node.name]
                    assert local_qconfig is not None, \
                        'qconfig of a node before a quantized output must exist'
                    insert_observer(
                        prev_node, local_qconfig.activation(),
                        model, self.activation_post_process_map,
                        env, observed_graph, load_arg, observed_node_names_set)

            observed_graph.output(load_arg(node.args[0]))
            continue

        if node.name in observed_node_names_set:
            continue

        root_node, matched_nodes, pattern, obj, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        # Every node (matched or not) is copied into the observed graph;
        # only the root node of a match that has a qconfig gets the extra
        # special-module and output observers.
        env[node.name] = observed_graph.node_copy(node, load_arg)
        if root_node is node and qconfig is not None:
            assert obj is not None
            insert_observer_for_special_module(
                obj, self.modules, prepare_custom_config_dict, qconfig,
                node)
            insert_observer_for_output_of_the_node(
                node, obj, qconfig, self.modules, model, pattern,
                self.activation_post_process_map, env,
                observed_graph, load_arg, observed_node_names_set,
                matched_nodes)

        if node.op == 'placeholder':
            # skip adding observers at the graph input if the input is
            # overriden to be quantized
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                observed_node_names_set.add(node.name)
                continue

        insert_observer_for_input_arg_of_observed_node(
            node, observed_node_names_set, quants,
            model, self.activation_post_process_map, env,
            observed_graph, load_arg)

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    model = mark_observed_module(model)
    return model
def _prepare(self, model: GraphModule, qconfig_dict: Any,
             prepare_custom_config_dict: Optional[Dict[str, Any]],
             is_standalone_module: bool) -> GraphModule:
    """ Insert observers into ``model`` and return the observed module.

    standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    both input and output are observed in prepared standalone module

    Args:
        model: traced model to prepare; qconfigs are propagated onto it and,
            when ``model.training`` is set, QAT modules are swapped in. Its
            parameters must live on a single device
            (``assert_and_get_unique_device``).
        qconfig_dict: qconfig dict used to propagate qconfigs and to build
            the node-name -> qconfig map.
        prepare_custom_config_dict: optional custom config (``None`` is
            treated as ``{}``).
        is_standalone_module: not read by this method.

    Returns:
        model(GraphModule): the observed model, after ``save_state`` and
        ``mark_observed_module``.
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    self.patterns = get_combined_dict(get_default_quant_patterns(),
                                      additional_quant_patterns)

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        self._qat_swap_modules(model, additional_qat_module_mapping)

    # collected after the optional QAT swap so it reflects swapped modules
    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    standalone_module_classes = prepare_custom_config_dict.get(
        "standalone_module_class", None)
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    assert self.patterns is not None
    matches = self._find_matches(model.graph, self.modules, self.patterns,
                                 standalone_module_names,
                                 standalone_module_classes,
                                 custom_module_classes)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize an DefaultQuantizeHandler object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    env: Dict[Any, Any] = {}
    observed_graph = Graph()
    observed_node_names_set: Set[str] = set()

    def load_arg(a):
        # look up the observed-graph counterpart of every Node inside `a`
        return map_arg(a, lambda node: env[node.name])

    model_device = assert_and_get_unique_device(model)

    for node in model.graph.nodes:
        if node.op == 'output':
            observed_graph.output(load_arg(node.args[0]))
            continue

        if node.name in observed_node_names_set:
            continue

        root_node, matched_nodes, pattern, obj, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        # Every node (matched or not) is copied into the observed graph;
        # only the root node of a match that has a qconfig gets the extra
        # special-module and output observers.
        env[node.name] = observed_graph.node_copy(node, load_arg)
        if root_node is node and qconfig is not None:
            assert obj is not None
            insert_observer_for_special_module(
                obj, self.modules, prepare_custom_config_dict, qconfig,
                node)
            insert_observer_for_output_of_the_node(
                node, obj, qconfig, self.modules, model, pattern,
                model_device,
                self.activation_post_process_map, env,
                observed_graph, load_arg, observed_node_names_set,
                matched_nodes)

        insert_observer_for_input_arg_of_observed_node(
            node, observed_node_names_set, quants,
            model_device, model, self.activation_post_process_map, env,
            observed_graph, load_arg)

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    model = mark_observed_module(model)
    return model
def insert_observers_for_model(
    model: GraphModule,
    modules: Dict[str, torch.nn.Module],
    matches: Dict[str, MatchResult],
    qconfig_map: Dict[str, QConfigAny],
    graph: Graph,
    prepare_custom_config_dict: Dict[str, Any],
    equalization_config_map: Dict[str, Any],
    input_quantized_idxs: List[int],
    output_quantized_idxs: List[int],
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and save
           it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
           dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to be
    complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.

    Graph inputs and outputs are matched against ``input_quantized_idxs`` and
    ``output_quantized_idxs`` by their position among the placeholder and
    output nodes respectively, in graph order.

    Returns:
        The graph's ``output`` node, or None if the graph has none.
        ``model.graph`` is mutated in place.
    """

    node_name_to_target_dtype: Dict[str, Any] = {}
    cache_for_no_tensor_check: Dict[Node, bool] = dict()

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overriden
    # other nodes output dtype is specified by the qconfig
    modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        _, _, _, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        node_name_to_target_dtype[node.name] = get_target_activation_dtype_for_node(
            node, qconfig, inputs_seen_counter, outputs_seen_counter,
            input_quantized_idxs, output_quantized_idxs, qhandler,
            modules, cache_for_no_tensor_check)
        # Advance the positional counters in the same pass that consumes
        # them, so that the k-th placeholder / output is checked against
        # index k of input_quantized_idxs / output_quantized_idxs.
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1

    # Second, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(
        model.graph, node_name_to_target_dtype, matches)

    # After this point, the current node and all of its arguments
    # have a dtype assigned. Now, we insert observers for inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    for node in nodes_before_observation:

        if node.op == 'placeholder':
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            #   outside of the graph, and no additional observation is needed
            pass

        elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
            # check for matches
            root_node, _, pattern, qhandler, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            equalization_qconfig = equalization_config_map.get(node.name, None)

            this_node_dtype = node_name_to_target_dtype[node.name]
            output_not_a_tensor = this_node_dtype is None
            # TODO(future PR): consider stopping matching getitem
            is_getitem = node.op == 'call_function' and \
                node.target == operator.getitem

            skip_inserting_observers = (
                (qconfig is None) or
                output_not_a_tensor or
                is_getitem
            ) and (not node.op == 'output')

            if not skip_inserting_observers:
                # refreshed because earlier iterations may have added
                # observer modules to the model
                modules = dict(model.named_modules(remove_duplicate=False))
                if node.op != 'output':
                    # This is currently only used for equalization.
                    # Checks if the current node is in a branch in which the two
                    # first layers are both being quantized.
                    #
                    # ex.       conv2
                    #         /
                    #      x -> conv1
                    #
                    # If this is the case, we will not apply equalization to the
                    # initial two layers.
                    is_quantized_branch = False
                    if (
                        len(node.args) > 0 and
                        isinstance(node.args[0], Node) and
                        len(node.args[0].users) > 1
                    ):
                        for user in node.args[0].users:
                            # Checks if there exists another user being quantized
                            is_user_quantized = (
                                qconfig_map.get(user.name, None) is not None or
                                (user.op == 'call_module' and
                                 isinstance(modules[str(user.target)], ObserverBase))
                            )
                            if user != node and is_user_quantized:
                                is_quantized_branch = True

                    # this modifies node inplace
                    maybe_insert_input_observers_for_node(
                        node, qconfig, model, modules, graph,
                        node_name_to_target_dtype,
                        qhandler, prepare_custom_config_dict)

                    # Insert equalization input observers if needed
                    maybe_insert_input_equalization_observers_for_node(
                        node, equalization_qconfig, model, modules, graph,
                        node_name_to_target_dtype, is_quantized_branch)

                    is_last_node_of_pattern = root_node is node
                    is_general_tensor_value_op = \
                        (qhandler is not None and qhandler.is_general_tensor_value_op())

                    is_general_tensor_shape_op = \
                        (qhandler is not None and qhandler.is_general_tensor_shape_op())

                    if is_last_node_of_pattern and not is_general_tensor_shape_op:
                        # this returns the new observer node if it was needed
                        maybe_output_obs_node = maybe_insert_output_observer_for_node(
                            node, model, modules, graph, matches,
                            node_name_to_target_dtype, pattern, qhandler)
                        if maybe_output_obs_node is not None:
                            # Update users of original node to use the output observer
                            # instead. For example, change
                            #
                            #           next_node
                            #          /
                            #   cur_node -> obs
                            #
                            # to
                            #
                            #                 next_node
                            #                 /
                            #   cur_node -> obs
                            #
                            # We need to save orig users before updating uses because
                            # the list of users will change as we update uses
                            orig_users = list(node.users.keys())
                            for user_node in orig_users:
                                if user_node is maybe_output_obs_node:
                                    continue
                                user_node.replace_input_with(node, maybe_output_obs_node)

                            # for general tensor value ops, we modify the graph
                            # to make all inputs and outputs use the first input's
                            # observer
                            if is_general_tensor_value_op:
                                if not maybe_make_input_output_share_observers(node, model, modules):
                                    remove_output_observer(node, model, modules)

                            if isinstance(qhandler, CustomModuleQuantizeHandler):
                                swap_custom_module_to_observed(
                                    node, qconfig, modules, prepare_custom_config_dict)

                else:  # output
                    maybe_insert_observers_before_graph_output(
                        node, output_quantized_idxs,
                        node_name_to_target_dtype, qconfig_map,
                        model, modules, graph)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        if node.op == 'output':
            results_node = node

    return results_node
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, copies it node by node into a new graph, and adds
    loggers to the inputs and/or outputs of the instrumented nodes.
    Returns a GraphModule with the new graph.

    Args:
        gm: model whose graph is copied.
        node_to_instrument_inputs_to_ref_node_name: nodes whose inputs get
            loggers, mapped to ``(ref_name, ref_node_type)``.
        node_to_instrument_outputs_to_ref_node_name: nodes whose output gets
            a logger, mapped to ``(ref_name, ref_node_type)``.
        logger_cls: logger class forwarded to ``_insert_logger_after_node``.
        model_name: model name forwarded to each inserted logger.

    Returns:
        ``GraphModule(gm, new_graph)`` built from the instrumented graph.
    """
    new_graph = Graph()
    # old node name -> node in new_graph that currently carries its value
    # (after input/output loggers are inserted this is the logger node)
    env: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        instrument_inputs = node in node_to_instrument_inputs_to_ref_node_name
        instrument_outputs = node in node_to_instrument_outputs_to_ref_node_name

        if not (instrument_inputs or instrument_outputs):
            env[node.name] = new_graph.node_copy(node, load_arg)
            continue

        fqn = _maybe_get_fqn(node, gm)

        if instrument_inputs:
            ref_name, ref_node_type = \
                node_to_instrument_inputs_to_ref_node_name[node]
            # Ops such add and mul are special because either
            # one or two of the first two arguments can be tensors,
            # and if one argument is a tensor it can be first or
            # second (x + 1 versus 1 + x).
            arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
            for node_arg_idx in arg_indices_to_log:
                node_arg = node.args[node_arg_idx]
                if isinstance(node_arg, Node):
                    # create a single input logger
                    prev_node = env[node_arg.name]
                    env[node_arg.name] = _insert_logger_after_node(
                        prev_node, gm, logger_cls, '_ns_logger_', node.name,
                        model_name, ref_name, ref_node_type,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0, index_of_arg=node_arg_idx,
                        fqn=fqn)
                elif isinstance(
                        node_arg, torch.fx.immutable_collections.immutable_list):
                    # create N input loggers, one for each node
                    for arg_idx, arg in enumerate(node_arg):
                        if not isinstance(arg, Node):
                            # non-tensor constants in the list are not logged
                            continue
                        prev_node = env[arg.name]
                        # Key by the ORIGINAL arg name (as the single-Node
                        # branch does) so load_arg below routes this node's
                        # input through the new logger even when env[arg.name]
                        # was already replaced by another logger.
                        env[arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                            fqn=fqn)

        # ensure env is populated with base node
        # Note: runs for both inputs and outputs
        env[node.name] = new_graph.node_copy(node, load_arg)

        if instrument_outputs:
            ref_name, ref_node_type = \
                node_to_instrument_outputs_to_ref_node_name[node]
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(
                env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                model_name, ref_name, ref_node_type,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0, index_of_arg=0, fqn=fqn)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
def prepare(
        model: GraphModule,
        qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
        equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False) -> ObservedGraphModule:
    """ standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
            The scope is a tuple of fully qualified path of the module and the type of the module
        prepare_custom_config_dict: optional custom config (``None`` is treated as ``{}``)
        equalization_qconfig_dict: optional equalization qconfig dict (``None`` is treated as ``{}``)
        backend_config_dict: optional backend config; defaults to
            ``get_fbgemm_backend_config_dict()`` and is always validated
    Returns:
        model(GraphModule): prepared standalone module
        attributes (only set when ``is_standalone_module`` is True, stored as
        ``torch.Tensor`` because module attributes must be Tensor or Module):
            _standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph input that is expected to be quantized,
                same as input_quantized_idxs configuration provided
                for the standalone module
            _standalone_module_output_quantized_idxs(List[Int]): a list of
                indexes for the graph output that is quantized,
                same as output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if equalization_qconfig_dict is None:
        equalization_qconfig_dict = {}
    if backend_config_dict is None:
        backend_config_dict = get_fbgemm_backend_config_dict()

    validate_backend_config_dict(backend_config_dict)
    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    # mapping from a tuple of nodes in reverse order to uninitialized
    #   QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.quantization.fx.quantize.Add'>),
    # }
    quant_patterns = backend_config_dict["quant_patterns"]
    patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
        quant_patterns, additional_quant_patterns)

    convert_dict_to_ordered_dict(qconfig_dict)
    convert_dict_to_ordered_dict(equalization_qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)

    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        qat_swap_modules(model, additional_qat_module_mapping)
        qconfig_dict = update_qconfig_for_qat(qconfig_dict, additional_qat_module_mapping)

    qconfig_dict = update_qconfig_for_fusion(model, qconfig_dict)
    equalization_qconfig_dict = update_qconfig_for_fusion(model, equalization_qconfig_dict)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    modules = dict(model.named_modules())

    # fill qconfig_map, a map from node name to qconfig, used in find_matches
    equalization_qconfig_map = generate_qconfig_map(
        model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope)
    qconfig_map = generate_qconfig_map(
        model, modules, model.graph, qconfig_dict, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_name_configs = prepare_custom_config_dict.get(
        "standalone_module_name", [])
    standalone_module_class_configs = prepare_custom_config_dict.get(
        "standalone_module_class", [])

    # each config entry's first element is the module name / class
    standalone_module_names = [config[0] for config in standalone_module_name_configs]
    standalone_module_classes = [config[0] for config in standalone_module_class_configs]
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")

    matches = find_matches(
        model.graph, modules, patterns, qconfig_map, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    run_prepare_fx_on_standalone_modules(
        model, modules, matches, prepare_custom_config_dict)

    result_node = insert_observers_for_model(
        model, modules, matches, qconfig_map,
        model.graph, prepare_custom_config_dict,
        equalization_qconfig_map,
        input_quantized_idxs, output_quantized_idxs)

    save_state(model, qconfig_map, node_name_to_scope, patterns,
               prepare_custom_config_dict, equalization_qconfig_map)

    preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
    model = ObservedGraphModule(model, model.graph, preserved_attributes)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently " \
            "(not tuple, dict etc.)"
        # these inputs are observed in parent
        # converting List[int] to Tensor since module attribute is
        # Union[Tensor, Module]
        model._standalone_module_input_quantized_idxs = \
            torch.tensor(input_quantized_idxs)
        model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs)
    return model
def _convert_do_not_use(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """
    Convert an observed model (a module with observer calls) to a reference
    quantized model. The rule is simple:

    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to reference
       quantized modules; this requires knowing whether the dtype configured for
       the weight is supported in the backend. That is decided in the prepare
       step and the result is stored in observed_node_names, which is used here
       to decide whether a module needs to be swapped.

    standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details.

    Raises:
        AssertionError: if ``is_reference`` is False (only the reference
            option is supported).
        Exception: if input_quantized_idxs / output_quantized_idxs request an
            override (not supported yet).
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    # NOTE(review): `matches` is not used below; the call is kept in case
    # callers rely on find_matches' validation -- confirm before dropping it.
    matches = find_matches(model.graph, modules, patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        float32 observers are simply removed; observers of any other dtype
        not listed below are left in place.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [
                torch.quint8, torch.qint8, torch.float16]:
            node_type, quantize_op, qparams = get_quantize_node_info(
                observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node, ))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    # module types that may need a weight swap; loop-invariant, so built once
    weighted_module_types = set(WEIGHTED_MODULE_CLASSES).union(
        QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES)

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # TODO: remove the quantize node for the placeholder
                raise Exception("input_quantized_idxs is not supported yet")
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                # TODO: remove dequantize node if any
                raise Exception("output_quantized_idxs is not supported yet")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(
                    model.graph, node, modules)
            elif type(modules[node.target]) in weighted_module_types:
                # TODO: refactor this part to a function
                original_module = modules[node.target]
                qconfig = original_module.qconfig

                is_observed = node.name in observed_node_names
                # Evaluate the cheap checks first: weight_is_statically_quantized
                # expects a real qconfig, so it must never be reached with None.
                # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
                if qconfig is None or not is_observed or \
                        not weight_is_statically_quantized(qconfig):
                    continue

                float_module = original_module
                fused_module = None
                if isinstance(original_module, QAT_MODULE_CLASSES):
                    # case 1. converting qat module to
                    # a float module, we need to attach
                    # weight fake_quant to the module,
                    # weight fake_quant is assumed to be run during
                    # QAT so we don't need to run it again here
                    float_module = original_module.to_float()  # type: ignore[operator]
                    # change qat conv to conv
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, float_module)
                    if isinstance(float_module, torch.nn.intrinsic._FusedModule):
                        fused_module = float_module
                        float_module = fused_module[0]
                    weight_post_process = original_module.weight_fake_quant
                else:
                    # case 2. converting a float module/fused float module
                    # to float module, we need to attach
                    # weight observer to the conv module and run it
                    # with conv weight
                    if isinstance(original_module, torch.nn.intrinsic._FusedModule):
                        fused_module = original_module
                        float_module = fused_module[0]  # type: ignore[index]
                    assert qconfig is not None
                    weight_post_process = qconfig.weight()
                    # run weight observer
                    weight_post_process(float_module.weight)  # type: ignore[operator]
                weight_qparams = get_qparam_dict(weight_post_process)
                ref_qmodule_cls = get_static_quant_module_class(
                    type(float_module), is_reference=True)
                ref_qmodule = ref_qmodule_cls.from_float(
                    float_module, weight_qparams)
                if fused_module is not None:
                    # keep the fused container, swap only its first (weighted) child
                    fused_module[0] = ref_qmodule
                else:
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, ref_qmodule)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b

    Pairs whose input/output dtypes cannot be determined (or whose input qparams
    cannot be inferred for an int8 shadow of an fp32 node) are skipped with a
    printed message; node_b is then copied without shadowing.

    Returns:
        A new GraphModule rooted at gm_b whose graph is graph C.

    Raises:
        AssertionError: if input logging is requested for a start node whose
            first argument is neither a Node nor a list.
    """
    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    # Map the start and end node of each subgraph of B to everything needed to
    # shadow it. subgraph_b is stored in the tuple so that the per-node loop
    # below uses the B subgraph that actually matches node_b (for its fqn).
    start_node_b_to_matched_subgraph_a_and_name: \
        Dict[Node, Tuple[NSSubgraph, NSSubgraph, str, str, str]] = {}
    end_node_b_to_matched_subgraph_a_and_name: \
        Dict[Node, Tuple[NSSubgraph, NSSubgraph, str, str, str]] = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, subgraph_b, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, subgraph_b, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, subgraph_b, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, subgraph_b, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8 and
                node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                        ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name, ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [
                            env_c[arg.name] for arg in node_b.args[0]]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name, ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0,
                                fqn=fqn_base_b)
                    else:
                        # logging of inputs which are not lists is not supported yet
                        raise AssertionError(f"type {type(node_b.args[0])} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                prev_node_c = node_c.args[0]
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = prev_node_c.args[0]
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [arg.args[0] for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                    node_b.name + '_dtype_cast_', logger_cls,
                    node_type_to_io_type_map)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # The copy of subgraph A does not exist yet, so the logger's
                    # ref_node_name starts empty and is filled in further down,
                    # once the copy has been created.
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                            ref_node_name, name_a, ref_name, ref_node_type_a,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_a)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name, ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx,
                                index_of_arg=0,
                                fqn=fqn_base_a)
                            new_loggers.append(dtype_cast_logger)
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second param as well. Note: dtype casting
                # for the second param is not implemented yet, it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    node_c_second_non_param_arg = node_c.args[1]
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg,
                    subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
                    cur_node = node_a_shadows_c
                    while cur_node.args[0] != input_logger:
                        cur_node = cur_node.args[0]  # type: ignore[assignment]
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                    node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0,
                    fqn=fqn_base_a)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                    node_b.name, name_b, ref_name, ref_node_type_b,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0,
                    fqn=fqn_base_b)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes inbetween.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Rebuild the graph of ``gm`` without its observers and with loggers attached.

    Every activation-post-process (observer) call is dropped: its users are
    rewired to the observer's input. Nodes listed in the two instrumentation
    maps get loggers on their inputs and/or their output.

    Args:
        gm: module whose graph is copied.
        node_to_instrument_inputs_to_ref_node_name: nodes whose inputs are
            logged, mapped to the reference name passed to those loggers.
        node_to_instrument_outputs_to_ref_node_name: nodes whose output is
            logged, mapped to the reference name passed to that logger.
        logger_cls: logger class forwarded to ``_insert_logger_after_node``.
        model_name: model name forwarded to every inserted logger.

    Returns:
        A new GraphModule built from ``gm`` and the rewritten graph.

    Raises:
        AssertionError: if an input-instrumented node's first argument is
            neither a Node nor an fx ``immutable_list``.
    """
    out_graph = Graph()
    # name of a node in gm.graph -> the node in out_graph that now represents it
    remap: Dict[str, Any] = {}
    named_mods = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda n: remap[n.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            out_graph.output(map_arg(node.args[0], load_arg))
            continue

        is_observer = (
            node.op == 'call_module'
            and is_activation_post_process(named_mods[node.target]))
        if is_observer:
            # drop the observer: anything reading it reads its input instead
            remap[node.name] = remap[node.args[0].name]
            continue

        log_inputs = node in node_to_instrument_inputs_to_ref_node_name
        log_output = node in node_to_instrument_outputs_to_ref_node_name

        if not (log_inputs or log_output):
            # plain node: copy it over unchanged
            remap[node.name] = out_graph.node_copy(node, load_arg)
            continue

        if log_inputs:
            ref_name = node_to_instrument_inputs_to_ref_node_name[node]
            first_arg = node.args[0]
            if type(first_arg) is Node:
                # a single input gets a single logger
                remap[first_arg.name] = _insert_logger_after_node(
                    remap[first_arg.name], gm, logger_cls, '_ns_logger_',
                    node.name, model_name, ref_name,
                    NSSingleResultValuesType.NODE_INPUT.value,
                    index_within_arg=0)
            elif type(first_arg) is torch.fx.immutable_collections.immutable_list:
                # a list input gets one logger per element
                for idx, arg in enumerate(first_arg):
                    src = remap[arg.name]
                    remap[src.name] = _insert_logger_after_node(
                        src, gm, logger_cls, '_ns_logger_',
                        node.name, model_name, ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=idx)
            else:
                raise AssertionError(
                    f"type {type(node.args[0])} is not handled yet")

        # the base node is copied whether inputs, output, or both are logged
        remap[node.name] = out_graph.node_copy(node, load_arg)

        if log_output:
            ref_name = node_to_instrument_outputs_to_ref_node_name[node]
            # the output logger sits right after the copied base node
            remap[node.name] = _insert_logger_after_node(
                remap[node.name], gm, logger_cls, '_ns_logger_',
                node.name, model_name, ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)

    return GraphModule(gm, out_graph)
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        convert_qconfig_dict: Dict[str, Any] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    Convert an observed model (a module with observer calls) to a quantized
    model. The rule is simple:

    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to reference
       quantized modules; this requires knowing whether the dtype configured for
       the weight is supported in the backend. That is decided in the prepare
       step and the result is stored in observed_node_names, which is used here
       to decide whether a module needs to be swapped.

    When ``is_reference`` is False, the model is moved to CPU first and the
    reference graph is additionally lowered (``lower_to_fbgemm``) with redundant
    quantize/dequantize nodes cleaned up.

    If ``convert_qconfig_dict`` is given, every qconfig it produces must either
    equal the one used during prepare or be None; otherwise an assertion fails.

    standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO this should be removed now that gpu support for quantization is being supported.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if convert_qconfig_dict:
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)
        convert_dict_to_ordered_dict(convert_qconfig_dict)
        if model._is_qat:
            convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, {})
        convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)

        compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict)  # type: ignore[arg-type]
        convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
        # check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
        # or are set to None in the convert_qconfig_map.
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
            if convert_qconfig_map[k] is not None:
                assert qconfig_equals(v, convert_qconfig_map[k]), (
                    'Expected k {} to have the same value in prepare qconfig_dict '
                    'and convert qconfig_dict, found {} updated to {}.'
                ).format(k, v, convert_qconfig_map[k])
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # TODO: move this outside of this function
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        The observer is simply removed when every producer/consumer has a None
        qconfig or when no quantize op is known for the observer.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers to quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all(
            has_none_qconfig(n, qconfig_map)
            for n in list(node.args) + list(node.users.keys()))
        if skip_replacement or maybe_quantize_node_info is None:
            # didn't find corresponding quantize op and info for the observer_module
            # so we just remove the observer
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise, we can convert the observer module call to quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # this is a temporary hack for custom module, we may want to implement
    # this properly after the custom module class design is finalized
    def replace_observer_with_dequantize_node(node: Node, graph: Graph):
        call_custom_module_node = node.args[0]
        assert isinstance(call_custom_module_node, Node), \
            f"Expecting the call custom module node to be a Node, but got {call_custom_module_node}"
        node.replace_all_uses_with(call_custom_module_node)
        graph.erase_node(node)
        insert_dequantize_node(call_custom_module_node, graph)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = get_native_backend_config_dict()
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config_dict)
    fused_module_classes = get_fused_module_classes(backend_config_dict)
    # every module type that convert_weighted_module may handle; loop-invariant
    weighted_module_classes = set(root_module_classes).union(
        qat_module_classes).union(fused_module_classes)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Result are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    replace_observer_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(modules[node.target]):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in weighted_module_classes:
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type(modules[node.target]) in fused_module_classes and \
                        type(modules[node.target][0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif type(modules[node.target]) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model.recompile()

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = duplicate_dequantize_node(model)
        model = duplicate_quantize_dynamic_node(model)
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
        model = remove_quant_dequant_pairs(model)
        model = remove_extra_dequantize(model)
    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    return model
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs: {'op0': ((op0_fp32, op0_fp32), (op0_int8, op0_int8)),
                             'op1': ((op1_fp32, op1_fp32), (op1_int8, op1_int8))}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b

    Args:
        name_a: model name passed to the loggers attached to the copies of A.
        gm_a: the model whose subgraphs are copied in as shadows.
        name_b: model name passed to the loggers attached to B's nodes.
        gm_b: the model whose graph is copied to form graph C.
        matched_subgraph_pairs: maps a match name to
            ((start_node_a, end_node_a), (start_node_b, end_node_b)).
            The B side must be a single node (start_node_b is end_node_b).
        logger_cls: logger class passed to `_insert_logger_after_node`.

    Returns:
        GraphModule(gm_b, graph_c). Activation post process calls of gm_b
        are dropped; each one forwards its input instead.
    """
    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    # index the matches by the (single) B node they shadow
    node_b_to_matched_subgraph_a = {}
    for (node_start_a, node_end_a), (node_start_b, node_end_b) in \
            matched_subgraph_pairs.values():
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a[node_end_b] = (node_start_a, node_end_a)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a:
            node_start_a, node_end_a = node_b_to_matched_subgraph_a[node_b]

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # prev_node_c -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # prev_node_c -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                env_c[dtype_cast_node.name], node_start_a, node_end_a, gm_a,
                gm_b, node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy(args/kwargs not shown)
            #      /
            # prev_node_c -> node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy
            #      /
            # prev_node_c -> node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls,
                '_ns_logger_a_', name_a, node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy --> logger_a
            #      /
            # prev_node_c -> node_c --> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
def convert(model: GraphModule, is_reference: bool = False,
            convert_custom_config_dict: Optional[Dict[str, Any]] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ Convert an observed GraphModule into a quantized GraphModule.

    Walks the nodes of `model.graph` and builds a new graph in which every
    node may exist in several dtype variants (tracked in `env`). Matched
    patterns are converted through their QuantizeHandler. Activation post
    process calls become quantize ops, or are dropped when the input already
    has the needed dtype. When `is_reference` is False, the result is also
    passed through `fold_weight` and `lower_to_fbgemm`.

    standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(
        model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO this should be removed now that gpu support for quantization is being supported.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph, modules, patterns, qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    quantized_graph = Graph()
    # node name -> {dtype -> node in quantized_graph}; a node may be
    # materialized in several dtypes (e.g. quantized and dequantized)
    env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(
        lambda: defaultdict(Node))  # type: ignore[arg-type]

    def load_non_quantized(n: Node) -> Node:
        # Return the float (or non-Tensor) variant of `n`, inserting a
        # dequantize into quantized_graph on demand if only quantized
        # variants exist.
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        dtype_to_node = env[n.name]
        if torch.float in dtype_to_node:
            return dtype_to_node[torch.float]
        elif None in dtype_to_node:
            return dtype_to_node[None]
        else:
            quantized_node = None
            for dtype in [torch.quint8, torch.qint8, torch.float16]:
                if dtype in dtype_to_node:
                    quantized_node = dtype_to_node[dtype]
                    break
            assert quantized_node is not None, "Did not find a supported quantized dtype:{}".format(
                dtype_to_node)
            env[n.name][torch.float] = Proxy(quantized_node).dequantize().node
            return env[n.name][torch.float]

    def load_quantized(dtype: torch.dtype):
        # Build a loader that returns the `dtype` variant of a node; float
        # (or None) requests are forwarded to load_non_quantized.
        def load_quantized_impl(n: Node):
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in environment:' + str(env)
            dtype_to_node = env[n.name]
            local_dtype: Optional[torch.dtype] = dtype
            if local_dtype == torch.float and local_dtype not in dtype_to_node:
                local_dtype = None
            if local_dtype in [torch.float, None]:
                return load_non_quantized(n)
            assert local_dtype in dtype_to_node, f'Expecting {dtype} in {dtype_to_node}'
            return dtype_to_node[local_dtype]

        return load_quantized_impl

    def load_x(n: Node) -> Node:
        # Return whichever variant of `n` exists, preferring quantized dtypes.
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        dtype_to_node = env[n.name]
        dtypes = [torch.quint8, torch.qint8, torch.float16, torch.float32, None]
        for dtype in dtypes:
            if dtype in dtype_to_node:
                return dtype_to_node[dtype]
        # report the searched dtypes (the loop variable would just be None)
        raise Exception(
            f'none of the dtypes {dtypes} were found in environment: '
            f'{dtype_to_node} for node {n.name}')

    def load_arg(
            quantized: Optional[Union[List[int], Dict[int, torch.dtype], torch.dtype, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, torch.dtype, list or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a dtype, then all args will be
            quantized to the specific dtype
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=torch.float)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized to torch.quint8
          - if quantized is a dict, then the args whose indexes are keys are
            loaded with the dtype stored as the value


        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, dict, torch.dtype)), type(quantized)
        if isinstance(quantized, (tuple, list, dict)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = torch.float

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], torch.dtype, Dict[int, torch.dtype], Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                if 0 in quantized:
                    updated_quantized = torch.quint8

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, torch.dtype):
                return map_arg(
                    arg_or_args,
                    load_quantized(updated_quantized))
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        # Currently it's hardcoded to torch.quint8, we can extend this
                        # in the future to support all quantized
                        # dtypes
                        loaded_args.append(
                            map_arg(a, load_quantized(torch.quint8)))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
            elif isinstance(updated_quantized, dict):
                loaded_args = []
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(
                            map_arg(a, load_quantized(updated_quantized[i])))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                dtype_to_node = env[node_arg.name]
                return any(
                    x in dtype_to_node
                    for x in [torch.quint8, torch.qint8, torch.float16])
            else:
                return False
        elif isinstance(node_arg, list):
            # materialize: all() and any() must both see every element.
            # A bare map() iterator would be partially consumed by all(),
            # so a partially quantized list could silently return False.
            quantized = list(map(node_arg_is_quantized, node_arg))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # for some ops the output is quantized only when `is_reference` is True
        # and when `is_reference` is False, it has limited qconfig
        # support, for example `add`
        # ideally this check should not happen here, it should happen either in
        # prepare or during lowering, we don't need this check
        # after the default path is changed to produce reference patterns
        quantized = obj.is_output_quantized(qconfig)

        # Need to get correct quantized/non-quantized state forn the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name][torch.float] = quantized_graph.node_copy(
                node, load_non_quantized)
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            prev_dtype_to_node: Dict[Optional[torch.dtype], Node] = env[prev_node.name]
            current_dtype: Optional[torch.dtype] = observer_module.dtype  # type: ignore[assignment]
            if current_dtype in prev_dtype_to_node:
                env[node.name][current_dtype] = prev_dtype_to_node[current_dtype]
            else:
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name][observer_dtype] = \
                    quantize_node(
                        load_non_quantized(prev_node),
                        observer_module, node, modules, quantized_graph,
                        node_name_to_scope, is_input=True)
        else:
            # replace activation post process with quantization ops
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name][dtype] = \
                quantize_node(
                    load_non_quantized(node.args[0]),
                    observer_module, node, modules,
                    quantized_graph,
                    node_name_to_scope, is_input=True)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue

        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module' and
                is_observed_standalone_module(modules[node.target])
            )
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # noqa: B950
                    assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                # Note: load_arg can be overwritten in the convert method when used to
                # create Node in graph
                result = obj.convert(
                    node, qconfig, modules, quantized_graph, node_name_to_scope, load_arg, is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig, modules)

            if quantized:
                env[node.name][activation_dtype(qconfig)] = result
            else:
                env[node.name][torch.float] = result
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                env[node.name][torch.float] = result
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name][torch.quint8] = quantized_graph.node_copy(
                    node, load_non_quantized)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name][None] = \
                    quantized_graph.node_copy(node, load_x)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
        model = lower_to_fbgemm(model)
    return model
def _convert_do_not_use(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details

    Raises:
        AssertionError: if `is_reference` is False (only the reference path
            is supported here).
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    def replace_observer_with_quantize_dequantize_node(graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        float32 observers are simply removed; observers with other dtypes
        outside (quint8, qint8, float16) are left untouched.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [torch.quint8, torch.qint8, torch.float16]:
            node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = {}
    quantized_reference_module_mapping = get_quantized_reference_module_mapping(backend_config_dict)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    weighted_module_classes = tuple(quantized_reference_module_mapping.keys())
    # loop-invariant: every module class that convert_weighted_module handles
    convertible_module_classes = set(weighted_module_classes) \
        .union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES)

    # iterate over a snapshot since the graph is mutated inside the loop
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specifid the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                # Remove the dequantize operator in the end
                maybe_dequantize_node = node.args[0]
                if isinstance(maybe_dequantize_node, Node) and \
                        maybe_dequantize_node.op == "call_method" and \
                        maybe_dequantize_node.target == "dequantize":
                    quantized_input_node = maybe_dequantize_node.args[0]
                    # only rewire the output's use: the dequantize may also
                    # feed other nodes, which must keep getting float inputs
                    node.replace_input_with(maybe_dequantize_node, quantized_input_node)
                    if len(maybe_dequantize_node.users) == 0:
                        model.graph.erase_node(maybe_dequantize_node)
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(model.graph, node, modules)
            elif is_observed_standalone_module(modules[node.target]):
                # TODO: move this to a separate function
                convert_standalone_module(node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in convertible_module_classes:
                convert_weighted_module(node, modules, observed_node_names, quantized_reference_module_mapping)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model