def _convert(self, model: GraphModule, debug: bool = False,
             convert_custom_config_dict: Optional[Dict[str, Any]] = None,
             is_standalone_module: bool = False) -> GraphModule:
    """Convert an observed ``GraphModule`` into a quantized ``GraphModule``.

    The conversion runs in two passes:

    1. Walk ``model.graph`` and build ``self.quantized_graph``. Nodes that are
       the root of a matched pattern are converted through the matched
       handler's ``convert`` method, activation_post_process calls are
       replaced by quantize ops, and every other node is copied with float
       inputs.
    2. Copy ``self.quantized_graph`` into a fresh graph while dropping the
       activation_post_process ``call_module`` nodes that are still present.

    standalone_module means a submodule that is not inlined in the parent
    module and will be quantized separately as one unit. The result is a
    quantized standalone module which accepts float input and produces
    float output.

    Args:
        model: the observed module to convert. It is moved to CPU and put
            in eval mode in place.
        debug: forwarded unchanged to each handler's ``convert`` call.
        convert_custom_config_dict: custom configuration for convert;
            ``None`` is treated as an empty dict.
        is_standalone_module: kept for interface compatibility; it is not
            read anywhere in this function body.

    Returns:
        A new ``GraphModule`` built from ``model`` and the graph with the
        activation_post_process nodes removed.

    Raises:
        Exception: if a list argument of a node mixes quantized and
            non-quantized entries (not handled yet).
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    self.restore_state(model)
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    self._run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    self.modules = dict(model.named_modules())

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    assert self.patterns is not None
    matches = self._find_matches(
        model.graph, self.modules, self.patterns,
        custom_module_classes=custom_module_classes)

    # NOTE(review): `quants` is not read below; the call is kept because
    # `_find_quants` is not visible here and may have side effects.
    quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
        self._find_quants(model.graph, matches)

    self.quantized_graph = Graph()
    # node name -> float (non-quantized) node in self.quantized_graph
    env: Dict[str, Node] = {}
    # node name -> quantized node in self.quantized_graph
    quant_env: Dict[str, Node] = {}

    def load_non_quantized(n: Node) -> Node:
        """Return the float version of `n`, dequantizing it on first use."""
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find ' + \
                'node:' + n.name + \
                ' in quantized or non quantized environment, env: ' + \
                str(env) + ' quant_env:' + str(quant_env)
            # cache the dequantized node so it is only created once
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n: Node) -> Node:
        """Return the quantized version of `n`; it must already exist."""
        assert n.name in quant_env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in quant environment:' + str(quant_env)
        return quant_env[n.name]

    def load_x(n: Node) -> Node:
        """Return `n` from whichever environment has it, preferring
        the quantized one."""
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
                 ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is None, then we'll load the node as long as it
            exists

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg_or_args):
            if quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        """Tell whether an argument (a Node or a list of arguments) is
        currently available in quantized form."""
        if isinstance(node_arg, Node):
            assert node_arg.name in env or node_arg.name in quant_env, \
                'Expecting node_arg to be in the environment'
            # there might be nodes appearing in both environments, but
            # quant_env will take precedence
            return node_arg.name in quant_env
        elif isinstance(node_arg, list):
            # Materialize the results: `map` returns a one-shot iterator, so
            # calling all() and then any() on it would let any() see only
            # the elements all() did not consume and miss mixed lists such
            # as [quantized, float].
            quantized = [node_arg_is_quantized(a) for a in node_arg]
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
        """ Check if output node is quantized or not """
        assert self.modules is not None
        # by default the output is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of CopyNode
        if type(obj) in [
                CopyNode,
                FixedQParamsOpQuantizeHandler
        ]:
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'CopyNode of type ' + node.op + ' is not handled'
            quantized = node_arg_is_quantized(node.args[0])

        # `qconfig` is read from the enclosing scope: it is the value bound
        # by the main loop below for the node currently being converted.
        if not activation_is_statically_quantized(qconfig) or \
                not input_output_observed(obj):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert self.modules is not None
        assert isinstance(node.target, str)
        observer_module = self.modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float16:
            # activations are not quantized for
            # fp16 dynamic quantization
            # copy the activation_post_process node here
            # since we may need it when we insert prepack
            # op for weight of linear, this will be removed
            # later in a separate pass
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)
        elif isinstance(prev_node, Node) and prev_node.name in quant_env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            quant_env[node.name] = quant_env[prev_node.name]
        else:
            # replace activation post process with quantization ops
            root_module = self.modules[""]
            assert isinstance(node.args[0], Node)
            quant_env[node.name] = quantize_node(
                root_module, self.quantized_graph,
                load_non_quantized(node.args[0]), observer_module)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == 'output':
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            self.quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            # this node is the root of a matched pattern
            if qconfig is None:
                # no qconfig for this pattern: keep the node in float
                result = self.quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                is_standalone_module_node = (
                    node.op == 'call_module' and
                    is_observed_standalone_module(
                        self.modules[node.target])  # type: ignore
                )
                result = obj.convert(
                    self, node, load_arg, debug=debug,
                    convert_custom_config_dict=convert_custom_config_dict)
                if is_standalone_module_node:
                    # standalone module outputs are recorded as float
                    quantized = False
                else:
                    quantized = is_output_quantized(node, obj)

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            # non-root node of a matched pattern: nothing is emitted for
            # it here
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(self.modules[node.target]):
            insert_quantize_node(node)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # the user declared this graph input as already quantized
                quant_env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            env[node.name] = \
                self.quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    # node name -> node in act_post_process_removed_graph
    remove_env: Dict[str, Node] = {}

    def load_arg_simple(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in self.quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_simple))
            continue
        if node.op == 'call_module' and \
                is_activation_post_process(self.modules[node.target]):
            # remove activation post process node: users of this node are
            # redirected to the node that fed it
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_simple)

    # removes qconfig and activation_post_process modules
    _remove_qconfig(model)
    model = GraphModule(model, act_post_process_removed_graph)
    return model
def convert(model: GraphModule, is_reference: bool = False,
            convert_custom_config_dict: Optional[Dict[str, Any]] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """Convert an observed ``GraphModule`` into a ``QuantizedGraphModule``.

    The conversion runs in two passes:

    1. Walk ``model.graph`` and build a new quantized graph. Nodes that are
       the root of a matched pattern are converted through the matched
       handler's ``convert`` method, activation_post_process calls are
       replaced by quantize ops (or copied for fp32 observers), and every
       other node is copied. A single environment maps each node name to
       ``(node, dtype)`` so the dtype of every produced value is tracked.
    2. Copy the quantized graph into a fresh graph while dropping the
       activation_post_process ``call_module`` nodes that are still present.

    standalone_module means a submodule that is not inlined in the parent
    module and will be quantized separately as one unit. For a quantized
    standalone module, whether input/output is quantized is specified by
    prepare_custom_config_dict, with input_quantized_idxs and
    output_quantized_idxs; please see docs for prepare_fx for details.

    Args:
        model: the observed module to convert. It is moved to CPU and put
            in eval mode in place.
        is_reference: forwarded to each handler's ``convert`` call; when
            true, ``fold_weight`` is not applied to the result.
        convert_custom_config_dict: custom configuration for convert;
            ``None`` is treated as an empty dict. The keys read here are
            "observed_to_quantized_custom_module_class" (through
            ``get_custom_module_class_keys``) and "preserved_attributes".
        is_standalone_module: kept for interface compatibility; it is not
            read anywhere in this function body.
        _remove_qconfig_flag: when true, ``_remove_qconfig`` is called on
            ``model`` before the result is built.

    Returns:
        A ``QuantizedGraphModule`` built from ``model`` and the graph with
        the activation_post_process nodes removed.

    Raises:
        Exception: if a list argument of a node mixes quantized and
            non-quantized entries (not handled yet).
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(
        model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    quantized_graph = Graph()
    # node name -> (node in quantized_graph, dtype of the value it produces);
    # dtype is None for values recorded as non-Tensors
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    def load_non_quantized(n: Node) -> Node:
        """Return the float version of `n`, dequantizing it if needed."""
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        quantized_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            # the env entry is replaced, so from now on this name resolves
            # to the dequantized node
            env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        """Return `n`, which must be recorded with a quantized dtype."""
        assert n.name in env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in environment:' + str(env)
        quantized_node, dtype = env[n.name]
        assert dtype in [torch.quint8, torch.qint8, torch.float16], \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        """Return `n` as recorded, whatever its dtype is."""
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        return env[n.name][0]

    def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
                 ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        """Tell whether an argument (a Node or a list of arguments) is
        currently recorded with a non-float dtype."""
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            # Materialize the results: `map` returns a one-shot iterator, so
            # calling all() and then any() on it would let any() see only
            # the elements all() did not consume and miss mixed lists such
            # as [quantized, float].
            quantized = [node_arg_is_quantized(a) for a in node_arg]
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(
            node: Node, obj: QuantizeHandler, qconfig: QConfigAny,
            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(
            node: Node,
            modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                # dtype mismatch: quantize the float version of the
                # previous node to the observer's dtype
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name] = (
                    quantize_node(
                        load_non_quantized(prev_node),
                        observer_module, node, modules, quantized_graph,
                        node_name_to_scope, is_input=True),
                    observer_dtype)
        else:
            # replace activation post process with quantization ops
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name] = (
                quantize_node(
                    load_non_quantized(node.args[0]),
                    observer_module, node, modules, quantized_graph,
                    node_name_to_scope, is_input=True),
                dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            # this node is the root of a matched pattern
            is_observed_standalone_module_node = (
                node.op == 'call_module' and
                is_observed_standalone_module(modules[node.target])
            )
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # type: ignore[operator] # noqa: B950
                    assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                    node, qconfig, modules, quantized_graph,
                    node_name_to_scope, load_arg, is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(
                        node, obj, qconfig, modules)

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
            else:
                env[node.name] = result, torch.float
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                env[node.name] = result, torch.float
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # the user declared this graph input as already quantized;
                # it is recorded with dtype torch.quint8
                env[node.name] = \
                    quantized_graph.node_copy(
                        node, load_non_quantized), torch.quint8
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name] = \
                    quantized_graph.node_copy(node, load_x), None
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float

    # remove activation post process
    act_post_process_removed_graph = Graph()
    # node name -> node in act_post_process_removed_graph
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node: users of this node are
            # redirected to the node that fed it
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(
        model, act_post_process_removed_graph, preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
    return model