def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    """
    Append ``torch.quantize_per_tensor(prev_node_c, scale, zero_point, torch.quint8)``
    to ``graph_c`` and return the new node (named ``dtype_cast_name``).

    ``scale`` and ``zero_point`` are stored on ``gm_b`` as fresh attributes
    whose names are derived from ``node_a.name``. They are read back into the
    graph through ``get_attr`` nodes.
    """

    def _register_qparam(suffix: str, value: Union[torch.Tensor, float, int]) -> Node:
        # store the value on gm_b under a unique name and expose it via get_attr
        attr_name = get_new_attr_name_with_prefix(node_a.name + suffix)(gm_b)
        setattr(gm_b, attr_name, value)
        return graph_c.create_node('get_attr', attr_name, (), {}, attr_name)

    # scale must be registered before zero_point to keep attribute naming order
    scale_getattr = _register_qparam('_input_scale_', scale)
    zero_point_getattr = _register_qparam('_input_zero_point_', zero_point)

    quantize_args = (prev_node_c, scale_getattr, zero_point_getattr, torch.quint8)
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor, quantize_args, {},
        dtype_cast_name)
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                      dtype_cast
                    /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Args:
        prev_node_c: a single node, or a list of nodes, to cast. A list
            produces one cast node per element.
        node_name_prefix: prefix used to derive unique names (checked against
            ``gm_b``) for the inserted nodes.

    Returns:
        The new cast node, or a list of cast nodes if ``prev_node_c`` is a list.

    Raises:
        AssertionError: if the (node_a, node_c) input-type pair has no supported
            cast, or if ``prev_node_c`` is neither a Node nor a list.
    """
    node_input_type_a = get_node_input_type(node_a, gm_a)
    node_input_type_c = get_node_input_type(node_c, gm_b)

    if (node_input_type_a == NodeInputType.FP32 and
            node_input_type_c == NodeInputType.INT8):
        dtype_cast_op = torch.dequantize
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} to {node_input_type_a} needs to be implemented"
        )

    def _cast_one(prev: Node) -> Node:
        # each cast node gets its own unique name, checked against gm_b
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        return graph_c.create_node(
            'call_function', dtype_cast_op, (prev, ), {}, new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _cast_one(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_cast_one(prev_node_c_inner) for prev_node_c_inner in prev_node_c]
    else:
        # previously rendered as "type f<class ...>" due to a stray `f`
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Copy ``node_a`` (from graph A / ``gm_a``) into ``graph_c``.

    Supported cases:
      * ``get_attr``: the attribute value is fetched from ``gm_a`` (detached
        if it is a tensor), stored on ``gm_b`` under a fresh
        ``<name>_shadow_copy_*`` name, and referenced by a new get_attr node.
      * ``call_method`` with target ``dequantize`` or ``to``: the first input
        is copied recursively. ``to`` also forwards its second normalized
        input unchanged.

    Raises:
        AssertionError: for any other op, or an unsupported call_method target.
    """

    def _shadow_name() -> str:
        return get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)

    if node_a.op == 'get_attr':
        copy_name = _shadow_name()
        attr_value = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(attr_value):
            attr_value = attr_value.detach()
        setattr(gm_b, copy_name, attr_value)
        return graph_c.create_node(node_a.op, copy_name, (), {}, copy_name)

    if node_a.op != 'call_method':
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented"
        )

    assert node_a.target in ('dequantize', 'to'), \
        f"target {node_a.target} is not implemented"

    # both supported methods operate on a copy of their first input
    first_input_copy = _copy_node_from_a_to_c(
        get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
    copy_name = _shadow_name()
    if node_a.target == 'dequantize':
        call_args = (first_input_copy, )
    else:
        # `to`: its second argument is reused as-is (not copied into graph C)
        call_args = (first_input_copy, get_normalized_nth_input(node_a, gm_a, 1))
    return graph_c.create_node(node_a.op, node_a.target, call_args, {}, copy_name)
def insert_observer(
        node: Node, observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Register ``observer`` on ``model`` and route ``node``'s output through it.

    Every effect is in place:
      * ``observer`` is moved to ``model``'s device when
        ``assert_and_get_unique_device`` returns a truthy device;
      * ``observer`` is set as a new attribute on ``model`` named
        ``<node.name>_activation_post_process_*``;
      * ``activation_post_process_map[node.name]`` is set to ``observer``;
      * ``env[node.name]`` becomes a new ``call_module`` node in
        ``observed_graph`` that calls the observer on ``load_arg(node)``;
      * ``node.name`` is added to ``observed_node_names_set``.

    Args:
        node: Node whose output is observed
        observer: observer/fake_quantize module instance
    """
    target_device = assert_and_get_unique_device(model)
    if target_device:
        observer.to(target_device)

    make_observer_name = get_new_attr_name_with_prefix(
        node.name + '_activation_post_process_')
    attr_name = make_observer_name(model)
    setattr(model, attr_name, observer)

    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer

    observer_call = observed_graph.create_node(
        'call_module', attr_name, (load_arg(node),), {})
    env[node.name] = observer_call
    observed_node_names_set.add(node.name)
def insert_observer(
    node: Node,
    observed_op: Node,
    observer: ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    input_or_output: str,
) -> Node:
    """
    Attach ``observer`` to ``model`` and insert, right after ``node``, a
    ``call_module`` node that applies the observer to ``node``'s output.

    Attribute naming:
      * equalization observers: ``<node.name>_equalization_process_*``
      * all others: ``<scope>_<input_or_output>_activation_post_process_*``,
        where ``<scope>`` is the scope path recorded for ``observed_op`` in
        ``node_name_to_scope``, or ``node.name`` when that path is ``''``.

    The observer is also recorded in ``modules`` under the same name.
    Returns the newly created observer-call node.
    """
    target_device = assert_and_get_unique_device(model)
    if target_device:
        observer.to(target_device)

    # NOTE: We get the FQN of the module/op being observed here using the node_name_to_scope
    # Please don't change/update this behavior as it might impact how observer stats are transferred
    # from the train model to the inference model for some models.
    scope_path, _ = node_name_to_scope[observed_op.name]
    if scope_path == '':
        scope_path = node.name

    if is_equalization_observer(observer):
        name_prefix = node.name + '_equalization_process_'
    else:
        name_prefix = scope_path + '_' + input_or_output + '_activation_post_process_'

    attr_name = get_new_attr_name_with_prefix(name_prefix)(model)
    setattr(model, attr_name, observer)
    modules[attr_name] = observer

    with graph.inserting_after(node):
        observer_call = graph.create_node('call_module', attr_name, (node,), {})
    return observer_call
def insert_observer(
    node: Node,
    observer: torch.quantization.ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attach ``observer`` to ``model`` (as ``<node.name>_equalization_process_*``
    for equalization observers, otherwise ``<node.name>_activation_post_process_*``),
    record it in ``modules``, and insert right after ``node`` a ``call_module``
    node applying it to ``node``'s output. Returns that new node.
    """
    target_device = assert_and_get_unique_device(model)
    if target_device:
        observer.to(target_device)

    suffix = ('_equalization_process_' if is_equalization_observer(observer)
              else '_activation_post_process_')
    attr_name = get_new_attr_name_with_prefix(node.name + suffix)(model)
    setattr(model, attr_name, observer)
    modules[attr_name] = observer

    with graph.inserting_after(node):
        observer_call = graph.create_node('call_module', attr_name, (node,), {})
    return observer_call
def fold_weight(
        quantized: QuantizedGraphModule,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> QuantizedGraphModule:
    """
    Fold weight-prepacking subgraphs into precomputed attributes.

    For every ``call_function`` node whose target is in ``WEIGHT_PREPACK_OPS``,
    trace back through its producers (until get_attr nodes are reached, via
    ``collect_producer_nodes``), build a small graph module from them and run
    it to obtain the packed weight. Then rebuild the graph: the prepack node
    becomes a ``get_attr`` of the packed weight (stored on ``quantized`` under
    a name derived from the scope of the prepack node's first user), the other
    folded producer nodes are dropped, and everything else is copied.

    Returns a new ``QuantizedGraphModule`` wrapping the rebuilt graph.
    """
    # prepack node name -> packed weight value
    packed_weights: Dict[str, Any] = {}
    # name of every node being folded -> the prepack node it belongs to
    folded_nodes: Dict[str, Any] = {}

    prepack_nodes = (
        n for n in quantized.graph.nodes
        if n.op == 'call_function' and n.target in WEIGHT_PREPACK_OPS)
    for prepack in prepack_nodes:
        producers = collect_producer_nodes(prepack)
        if producers is None:
            continue
        for producer in producers:
            folded_nodes[producer.name] = prepack
        packer = graph_module_from_producer_nodes(quantized, producers)
        packed_weights[prepack.name] = packer()

    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    root = quantized
    for node in quantized.graph.nodes:
        owner = folded_nodes.get(node.name)
        if owner is None:
            # not part of any prepack chain: copy unchanged
            env[node.name] = folded_graph.node_copy(node, load_arg)
            continue
        if owner is not node:
            # producer consumed by a prepack op: drop it from the new graph
            continue
        weight = packed_weights[node.name]
        first_user = list(owner.users)[0]
        module_path, _ = node_name_to_scope[first_user.name]
        attr_name = get_new_attr_name_with_prefix(
            module_path + '_packed_weight_')(root)
        setattr(root, attr_name, weight)
        # the prepack node is replaced by a getattr of the packed weight
        env[node.name] = folded_graph.create_node('get_attr', attr_name, (), {})

    return QuantizedGraphModule(root, folded_graph, root.preserved_attr_names)
def replace_observer_with_quantize_dequantize_node(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    The observer call is simply removed (its uses rewired to its input) when
    either every producer and consumer of the observer node has a None
    qconfig, or ``get_quantize_node_info`` returns None for the observer.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = get_module_path_and_prefix(
        node, node_name_to_scope, qconfig_map)
    observer_module = modules[node.target]
    quantize_node_info = get_quantize_node_info(observer_module)

    neighbours = [*node.args, *node.users]
    no_qconfig_anywhere = all([
        has_none_qconfig(n, qconfig_map) for n in neighbours
    ])

    if no_qconfig_anywhere or quantize_node_info is None:
        # nothing to convert: drop the observer and forward its input
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    node_type, quantize_op, qparams = quantize_node_info
    with graph.inserting_before(node):
        quantize_args = [node.args[0]]
        for key, value in qparams.items():
            # TODO: we can add the information of whether a value needs to
            # be registered as an attribute in qparams dict itself
            if key in ('_scale_', '_zero_point_'):
                # scale / zero_point are registered as buffers on `model`
                # and read back through get_attr nodes
                # TODO: maybe need more complex attr name here
                value = create_getattr_from_value(
                    model, graph, module_path + prefix + key, value)
            # any other qparam (e.g. axis, dtype) is embedded as a graph literal
            quantize_args.append(value)

        quantized = graph.create_node(
            node_type, quantize_op, tuple(quantize_args), {})
        dequantized = graph.call_method("dequantize", args=(quantized, ))
        node.replace_all_uses_with(dequantized)
        graph.erase_node(node)
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Call ``quantized_graph.create_node(*create_node_args)`` and give the
    resulting node the same ``stack_trace`` as ``old_node``.
    """
    fresh_node = quantized_graph.create_node(*create_node_args)
    inherited_trace = old_node.stack_trace
    fresh_node.stack_trace = inherited_trace
    return fresh_node
def create_getattr_from_value(module: GraphModule, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.

    Args:
        module: module on which the buffer is registered (under a unique name
            derived from ``prefix``).
        graph: graph in which the ``get_attr`` node is created.
        prefix: prefix for the new buffer name.
        value: a Tensor (copied via ``clone().detach()``) or anything accepted
            by ``torch.tensor``.

    Returns:
        The new ``get_attr`` node referring to the registered buffer.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    # torch.tensor(<Tensor>) emits a UserWarning; clone().detach() is the
    # recommended equivalent for copying an existing tensor.
    if isinstance(value, torch.Tensor):
        buffer_value = value.clone().detach()
    else:
        buffer_value = torch.tensor(value)
    module.register_buffer(attr_name, buffer_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Register ``value`` as a buffer on ``module`` (under a unique name derived
    from ``prefix``) and return a ``get_attr`` node in ``graph`` pointing at it.

    Tensors are stored as ``value.clone().detach()``. Any other value is wrapped
    with ``torch.tensor(value, device=...)`` on the device reported by
    ``assert_and_get_unique_device(module)``.
    """
    attr_name = get_new_attr_name_with_prefix(prefix)(module)
    module_device = assert_and_get_unique_device(module)
    if isinstance(value, torch.Tensor):
        buffer_value = value.clone().detach()
    else:
        buffer_value = torch.tensor(value, device=module_device)
    module.register_buffer(attr_name, buffer_value)
    return graph.create_node("get_attr", attr_name)
def replace_observer_with_quantize_dequantize_node(
        graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    Dispatch on the observer's ``dtype``:
      * ``torch.float32``: the observer call is removed.
      * ``torch.quint8`` / ``torch.qint8`` / ``torch.float16``: lowered to a
        quantize node (from ``get_quantize_node_info``) followed by
        ``dequantize``. Scale/zero_point become buffers on the root module.
      * anything else: the graph is left unchanged.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    observer_module = modules[node.target]
    root_module = modules[""]

    if observer_module.dtype == torch.float32:
        # remove the node for now
        # TODO: support dynamic quant
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    if observer_module.dtype not in [torch.quint8, torch.qint8, torch.float16]:
        return

    node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
    with graph.inserting_before(node):
        quantize_args = [node.args[0]]
        for key, value in qparams.items():
            if key in ('_scale_', '_zero_point_'):
                # scale / zero_point are registered as buffers on the root
                # module and read back through get_attr nodes
                # TODO: maybe need more complex attr name here
                value = create_getattr_from_value(root_module, graph, key, value)
            # any other qparam (e.g. axis, dtype) is embedded as a graph literal
            quantize_args.append(value)

        quantized = graph.create_node(
            node_type, quantize_op, tuple(quantize_args), {})
        dequantized = graph.call_method("dequantize", args=(quantized, ))
        node.replace_all_uses_with(dequantized)
        graph.erase_node(node)
def _fold_weight(self, quantized):
    """
    Replace weight-prepacking subgraphs of ``quantized`` by precomputed
    packed weights.

    Each ``call_function`` node whose target is in ``WEIGHT_PREPACK_OPS`` and
    whose producer chain can be collected is evaluated once. The result is
    stored on ``quantized`` as ``_fx_pass_packed_weight_*`` and the prepack
    node is replaced by a ``get_attr`` of it. Producer nodes of that chain are
    dropped, and all other nodes are copied. Returns a new ``GraphModule``.

    NOTE(review): relies on ``Graph.result`` (older torch.fx API) to find the
    graph output — confirm against the torch version in use.
    """
    packed_weights = {}
    # name of every folded node -> the prepack node that consumes it
    folded_nodes = {}

    for prepack in quantized.graph.nodes:
        if prepack.op != 'call_function' or prepack.target not in WEIGHT_PREPACK_OPS:
            continue
        producers = collect_producer_nodes(prepack)
        if producers is None:
            continue
        for producer in producers:
            folded_nodes[producer.name] = prepack
        packer = graph_module_from_producer_nodes(quantized, producers)
        packed_weights[prepack.name] = packer()

    folded_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    make_packed_weight_name = get_new_attr_name_with_prefix(
        '_fx_pass_packed_weight_')
    root = quantized
    source_graph = quantized.graph
    for node in source_graph.nodes:
        owner = folded_nodes.get(node.name)
        if owner is None:
            # not part of any prepack chain: copy unchanged
            env[node.name] = folded_graph.node_copy(node, load_arg)
        elif owner is node:
            weight = packed_weights[node.name]
            attr_name = make_packed_weight_name(root)
            setattr(root, attr_name, weight)
            env[node.name] = folded_graph.create_node(
                'get_attr', attr_name, (), {})
        # else: a producer folded into a packed weight -> omitted

    folded_graph.output(load_arg(source_graph.result))
    return GraphModule(root, folded_graph)
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Rebuild ``fx_module.graph`` so that every node whose op code is
    ``old_op`` and whose target is ``old_target`` becomes an
    ``(new_op, new_target)`` node with the same args, kwargs and name.

    All other nodes are copied unchanged. The rebuilt graph is assigned back
    to ``fx_module.graph``."""
    rebuilt = Graph()
    remap: Dict[Node, Node] = {}
    lookup = remap.__getitem__

    for old_node in fx_module.graph.nodes:
        is_match = old_node.op == old_op and old_node.target == old_target
        if not is_match:
            remap[old_node] = rebuilt.node_copy(old_node, lookup)
            continue
        new_args = map_arg(old_node.args, lookup)
        new_kwargs = map_arg(old_node.kwargs, lookup)
        assert isinstance(new_args, tuple)
        assert isinstance(new_kwargs, dict)
        remap[old_node] = rebuilt.create_node(
            new_op, new_target, new_args, new_kwargs, old_node.name)

    fx_module.graph = rebuilt
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
    """Build an observed copy of ``model``'s graph with observer modules inserted.

    Args:
        model: the module to prepare; ``copy.deepcopy``-ed first unless
            ``inplace`` is True. Its ``graph`` is traversed node by node.
        qconfig_dict: propagated onto ``model`` via ``propagate_qconfig_``.
        inplace: when False, work on a deep copy of ``model``.
        is_dynamic_quant: selects ``get_dynamic_quant_patterns()`` instead of
            ``get_quant_patterns()``. It also makes matched root nodes skip all
            further observer handling in the loop (see the ``continue`` below).

    Returns:
        ``GraphModule(model, observed_graph)``, after ``self.save_state`` has
        been called on it.

    Side effects: sets ``self.is_dynamic_quant``, ``self.patterns``,
    ``self.modules`` and ``self.activation_post_process_map``, and registers
    every inserted observer as a new attribute on ``model``.
    """
    if not inplace:
        model = copy.deepcopy(model)
    self.is_dynamic_quant = is_dynamic_quant
    # TODO: allow user specified patterns
    if self.is_dynamic_quant:
        self.patterns = get_dynamic_quant_patterns()
    else:
        self.patterns = get_quant_patterns()

    propagate_qconfig_(model, qconfig_dict)
    # training-mode models go through self._qat_swap_modules before matching
    if model.training:
        self._qat_swap_modules(model)

    self.modules = dict(model.named_modules())

    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph)

    # match the patterns that will get quantized
    matches = self._find_matches(model.graph, self.modules, self.patterns)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize an DefaultQuant object for each
    quants = self._find_quants(model.graph, matches)

    # node name -> observer module inserted for that node
    self.activation_post_process_map = dict()

    # original node name -> corresponding node in observed_graph
    # (after an observer is inserted, this is the observer call node)
    env = {}
    observed_graph = Graph()
    # names of nodes whose output is already observed
    observed_node_names_set = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in model.graph.nodes:
        if node.name in observed_node_names_set:
            continue

        get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
        root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
        if root_node is None:
            # unmatched node: copy as-is
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)

            # Redefined each iteration; closes over model, env, observed_graph
            # and get_new_observer_name from the enclosing scope.
            def insert_observer(node, observer, device):
                observer_name = get_new_observer_name(model)
                setattr(model, observer_name, observer)
                self.activation_post_process_map[node.name] = observer
                env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)
                if device:
                    getattr(model, observer_name).to(device)

            # don't need to insert observer for output in dynamic quantization
            # (note: `continue` also skips the input-observer handling below
            # for this node)
            if self.is_dynamic_quant:
                continue

            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                # True for an observed Node, or a list whose elements are all
                # observed; falls through to None (falsy) for other types.
                def is_observed(input_arg):
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed_node_names_set
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                # NOTE(review): assumes node.args[0] is a Node (has .name) —
                # confirm this holds for scalar-on-the-left add/mul.
                if node.args[0].name in observed_node_names_set:
                    observed_node_names_set.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                new_observer = qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                insert_observer(node, new_observer, device)
        else:
            # matched, but not the root of its match: copy as-is
            env[node.name] = observed_graph.node_copy(node, load_arg)

        # observe inputs of quantized patterns that are not observed yet
        if node.name not in observed_node_names_set and node.name in quants:
            # the name is generated even when qconfig turns out to be None
            observer_name = get_new_observer_name(model)
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                new_observer = \
                    qconfig.weight() if is_weight else qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                if device:
                    new_observer.to(device)
                self.activation_post_process_map[node.name] = new_observer
                setattr(model, observer_name, self.activation_post_process_map[node.name])
                env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)
    # NOTE(review): uses Graph.result (older torch.fx API) as the graph
    # output — confirm against the torch version in use.
    observed_graph.output(load_arg(model.graph.result))

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    return model
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                      dtype_cast
                    /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Cast selection (by first-input type of node_a vs node_c):
      * fp32 <- int8 / fp16 / fp32_or_int8: ``torch.dequantize``
      * identical known types: a new ``torch.nn.Identity`` submodule on gm_b
      * int8 <- fp32: ``torch.quantize_per_tensor`` with node_a's input qparams
        (when ``get_node_input_qparams`` finds them)

    Returns:
        The new cast node, or a list of cast nodes (one per element) if
        ``prev_node_c`` is a list.

    Raises:
        AssertionError: for unsupported type pairs, or if ``prev_node_c`` is
            neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    def _cast_one(prev: Node) -> Node:
        # Create the cast for a single producer node. Shared by the single-node
        # and list paths so both handle quantize_per_tensor qparams identically.
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                # quantize_per_tensor needs scale/zero_point/dtype arguments
                return _insert_quantize_per_tensor_node(
                    prev, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            return graph_c.create_node(
                'call_function', dtype_cast_op, (prev, ), {},
                new_dtype_cast_name)
        assert dtype_cast_mod_cls
        dtype_cast_mod = dtype_cast_mod_cls()
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
        return graph_c.create_node(
            'call_module', new_dtype_cast_name, (prev, ), {},
            new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _cast_one(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_cast_one(prev_node_c_inner) for prev_node_c_inner in prev_node_c]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module):
    """ standalone_module means it is a submodule that is not inlined in
    the parent module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    input of the module is observed in parent module, output of the module
    is observed in the standalone module.

    Recognized ``prepare_custom_config_dict`` keys (all optional):
        "additional_quant_pattern": extra entries merged into the patterns
        "additional_qat_module_mapping": passed to ``self._qat_swap_modules``
            when ``model.training`` (the historical misspelling
            "additioanl_qat_module_mapping" is still accepted)
        "standalone_module_name": forwarded to ``self._find_matches``
        "float_to_observed_custom_module_class": forwarded to
            ``self._find_matches`` and used to swap custom modules

    Returns:
        model(GraphModule): prepared standalone module with following
        attributes:
            _standalone_module_observed_input_idxs(List[Int]): a list of
                indices for the graph inputs that need to be observed in
                parent module
            _output_is_observed(Bool): a boolean variable indicating whether
                the output of the custom module is observed or not
        (the two attributes are only set when ``is_standalone_module``)
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if not inplace:
        model = copy.deepcopy(model)
    additional_quant_patterns = prepare_custom_config_dict.get(
        "additional_quant_pattern", {})
    self.patterns = get_default_quant_patterns().copy()
    for k, v in additional_quant_patterns.items():
        self.patterns[k] = v

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        # The key used to be read with a typo ("additioanl_..."), which
        # silently ignored the correctly spelled key. Honour the correct
        # spelling and keep the old one as a fallback for existing callers.
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping",
            prepare_custom_config_dict.get("additioanl_qat_module_mapping", {}))
        self._qat_swap_modules(model, additional_qat_module_mapping)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    custom_module_class_mapping = prepare_custom_config_dict.get(
        "float_to_observed_custom_module_class", None)
    matches = self._find_matches(
        model.graph, self.modules, self.patterns,
        standalone_module_names,
        custom_module_class_mapping)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize an DefaultQuant object for each
    quants = self._find_quants(model.graph, matches)

    # node name -> observer module inserted for that node
    self.activation_post_process_map = dict()
    # original node name -> corresponding node in observed_graph
    env = {}
    observed_graph = Graph()
    # names of nodes whose output is already observed
    observed_node_names_set = set()

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    # indexes for the inputs that needs to be observed
    standalone_module_observed_input_idxs = []
    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    result_node: Optional[Node] = None
    for node in model.graph.nodes:
        if node.op == 'output':
            observed_graph.output(load_arg(node.args[0]))
            result_node = node
            continue
        if node.name in observed_node_names_set:
            continue

        prefix = node.name + '_activation_post_process_'
        root_node, _, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is None:
            # unmatched node: copy as-is
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            if qconfig is None:
                continue

            # Redefined each iteration so it closes over the current `prefix`.
            def insert_observer(node, observer, device):
                get_new_observer_name = get_new_attr_name_with_prefix(
                    prefix)
                observer_name = get_new_observer_name(model)
                setattr(model, observer_name, observer)
                self.activation_post_process_map[node.name] = observer
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node), ), {})
                observed_node_names_set.add(node.name)
                if device:
                    getattr(model, observer_name).to(device)

            if isinstance(obj, CustomModuleQuantizeHandler):
                # swap the float custom module for its observed counterpart
                custom_module = self.modules[node.target]
                observed_custom_module_class = \
                    custom_module_class_mapping[type(custom_module)]
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name, observed_custom_module)

            # index for input of custom module that needs to be observed in parent
            standalone_module_input_idxs = None
            if isinstance(obj, StandaloneModuleQuantizeHandler):
                # observe standalone module
                standalone_module = self.modules[node.target]
                prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                observed_standalone_module = prepare(
                    standalone_module, {'': qconfig})
                observed_standalone_module.qconfig = qconfig
                standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                observed_standalone_module = mark_observed_standalone_module(
                    observed_standalone_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_standalone_module)
                self.modules[node.target] = observed_standalone_module

            # don't need to insert observer for output if activation does not
            # need to be statically quantized
            if not activation_is_statically_quantized(qconfig):
                continue

            # inserting observers for output of observed module, or mark the output
            # as observed
            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                def is_observed(input_arg):
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed_node_names_set
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                if node.args[0].name in observed_node_names_set:
                    observed_node_names_set.add(node.name)
            elif isinstance(obj, StandaloneModuleQuantizeHandler):
                assert node.op == 'call_module'
                output_is_observed = self.modules[
                    node.target]._output_is_observed
                if output_is_observed:
                    observed_node_names_set.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                new_observer = qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                insert_observer(node, new_observer, device)

            # insert observer for input of standalone module
            if standalone_module_input_idxs is not None:
                for idx in standalone_module_input_idxs:
                    if node.args[idx].name not in observed_node_names_set:
                        new_observer = qconfig.activation()
                        device = assert_and_get_unique_device(model)
                        insert_observer(node.args[idx], new_observer, device)
        else:
            # matched, but not the root of its match: copy as-is
            env[node.name] = observed_graph.node_copy(node, load_arg)

        # observe inputs of quantized patterns that are not observed yet
        if node.name not in observed_node_names_set and node.name in quants:
            if is_standalone_module and node.name in graph_inputs:
                # we'll insert observer for input of standalone module
                # in parent graph
                standalone_module_observed_input_idxs.append(
                    graph_inputs.index(node.name))
                continue
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                # TODO: use insert_observer
                new_observer = \
                    qconfig.weight() if is_weight else qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                if device:
                    new_observer.to(device)
                self.activation_post_process_map[node.name] = new_observer
                setattr(model, observer_name,
                        self.activation_post_process_map[node.name])
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node), ), {})
                observed_node_names_set.add(node.name)

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            'standalone module returning dict is not yet supported'
        # indicator for whether output is observed or not.
        # This used for correctly quantize standalone modules
        output_is_observed = result_node.args[
            0].name in observed_node_names_set
        model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
        model._output_is_observed = output_is_observed
    return model
class Quantizer:
    """Graph-mode quantization driver for a traced model.

    ``prepare`` / ``prepare_dynamic`` insert ``activation_post_process_<i>``
    observer modules into (a copy of) the model's root, and ``convert``
    replaces those observer calls with ``torch.quantize_per_tensor`` calls and
    the matched patterns with the output of their quantize handlers.
    """

    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map = None

    def _qat_swap_modules(self, root):
        """Swap modules of ``root`` in place using ``DEFAULT_QAT_MODULE_MAPPING``.

        qconfigs are kept on the swapped modules (``remove_qconfig=False``).
        """
        convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)

    def _generate_qconfig_map(self, root, input_graph):
        """Fill ``self.qconfig_map`` (node name -> qconfig or None).

        - ``get_param`` nodes use the qconfig of the module owning the param,
        - ``call_function`` nodes use the qconfig of ``root``,
        - ``call_method`` nodes reuse the qconfig of their ``self`` argument,
        - ``call_module`` nodes use the qconfig of the called module.

        Requires ``self.modules`` (fully qualified name -> module) to be set.
        """
        def get_qconfig(module):
            return module.qconfig if hasattr(module, 'qconfig') else None

        self.qconfig_map = {}
        for node in input_graph.nodes:
            if node.op == 'get_param':
                parent, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
            elif node.op == 'call_function':
                self.qconfig_map[node.name] = get_qconfig(root)
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self` object for the call
                self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
            elif node.op == 'call_module':
                self.qconfig_map[node.name] = get_qconfig(
                    self.modules[node.target])

    def _prepare(self, model, qconfig_dict, inplace, quant_type):
        """Return a GraphModule whose graph routes values through observers.

        Args:
            model: traced model exposing ``.root`` and ``.graph``.
            qconfig_dict: forwarded to ``propagate_qconfig_``.
            inplace: when False, ``model.root`` is deep-copied first.
            quant_type: ``QuantType.STATIC`` or ``QuantType.DYNAMIC``. For
                dynamic quantization no output observers are inserted for
                matched nodes.
        """
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.quant_type = quant_type
        # TODO: allow user specified patterns
        if self.quant_type == QuantType.DYNAMIC:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize an DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = {}

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # The helpers below do not depend on the loop variable, so they are
        # defined once here rather than re-created for every node.
        def get_new_observer_name(parent_module):
            # first 'activation_post_process_<i>' not yet an attribute of parent_module
            i = 0

            def get_observer_name(i):
                return 'activation_post_process_' + str(i)
            observer_name = get_observer_name(i)
            while hasattr(parent_module, observer_name):
                i += 1
                observer_name = get_observer_name(i)
            return observer_name

        def insert_observer(node, observer):
            # register the observer on the root and route `node` through it
            observer_name = get_new_observer_name(input_root)
            setattr(input_root, observer_name, observer)
            self.activation_post_process_map[node.name] = observer
            env[node.name] = observed_graph.create_node(
                'call_module', observer_name, [load_arg(node)], {})
            observed.add(node.name)

        def is_observed(input_arg):
            if isinstance(input_arg, Node):
                return input_arg.name in observed
            elif isinstance(input_arg, list):
                return all(map(is_observed, input_arg))

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                # don't need to insert observer for output in dynamic quantization
                if self.quant_type == QuantType.DYNAMIC:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif isinstance(obj, (Add, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    insert_observer(node, qconfig.activation())
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    self.activation_post_process_map[node.name] = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    setattr(input_root, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        return GraphModule(input_root, observed_graph)

    def prepare(self, model, qconfig_dict, inplace=False):
        """Insert observers for static quantization."""
        return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)

    def prepare_dynamic(self, model, qconfig_dict, inplace=False):
        """Insert observers for dynamic quantization."""
        return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)

    @staticmethod
    def _all_or_none_quantized(flags):
        """Collapse per-element "is quantized" flags of a list argument.

        Returns True when every flag is truthy (also for an empty list), False
        when none is, and raises when the flags are mixed.
        """
        # Materialize first: with a lazy ``map`` the ``all`` call would consume
        # part of the iterator, so ``any`` would only see the remainder and a
        # mixed list such as [True, False] would be misreported as False.
        flags = list(flags)
        if all(flags):
            return True
        if not any(flags):
            return False
        raise Exception("partially quantized inputs in list not handled yet")

    def convert(self, observed, inplace=False, debug=False):
        """Lower a model returned by ``prepare`` into a quantized GraphModule.

        Matched patterns are converted by their handler's ``convert``; observer
        calls are replaced by ``torch.quantize_per_tensor`` using the
        observer's qparams; ``activation_post_process_*`` submodules are then
        deleted from the root. ``debug`` is currently unused.
        """
        assert self.activation_post_process_map is not None
        # move to cpu since we only have quantized cpu kernels
        observed.eval().cpu()
        observed_root = observed.root
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)
        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules, self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        # env holds float values, quant_env holds quantized values
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either of the environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            if quantized is a list, then arg should be a list and the args with corresponding
            indexes will be quantized
            if quantized is a boolean, then all args will be quantized/not quantized
            if quantized is None, then we'll load the node as long as it exists
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg):
                if quantized is None:
                    return map_arg(arg, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg, (tuple, list)), arg
                    loaded_arg = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg):
                        if i in quantized:
                            loaded_arg.append(map_arg(a, load_quantized))
                        else:
                            loaded_arg.append(map_arg(a, load_non_quantized))
                    return type(arg)(loaded_arg)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                return self._all_or_none_quantized(map(is_quantized, node))

        def noattr(module, qparams, i):
            # True when none of the '<key><i>' attribute names is taken
            return not any(hasattr(module, name + str(i)) for name in qparams)

        def get_next_i(module, qparams):
            # smallest suffix i for which all qparam attribute names are free
            i = 0
            while not noattr(module, qparams, i):
                i += 1
            return i

        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                if self.quant_type == QuantType.DYNAMIC:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                # non-root nodes of a match are produced by the root's convert
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith(
                        'activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    # TODO: per channel
                    scale = float(scale)
                    zero_point = int(zero_point)
                    dtype = observer_module.dtype
                    qparams = {
                        '_scale_': scale,
                        '_zero_point_': zero_point,
                        '_dtype_': dtype
                    }

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(
                            self.quantized_graph.create_node(
                                'get_param', qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node(
                        'call_function', torch.quantize_per_tensor, inputs, {})
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        # collect first: the module tree must not change while it is iterated
        to_be_removed = [
            name for name, _ in observed_root.named_modules()
            if name.split('.')[-1].startswith('activation_post_process_')
        ]
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)

    def _find_matches(self, graph, modules, patterns):
        """Match ``patterns`` against ``graph``, scanning nodes in reverse.

        Returns a dict mapping every matched node name to
        ``(root_node, matched_nodes, handler_instance, qconfig)``. For each
        unmatched node the first pattern (in ``patterns`` order) that matches
        wins; nodes already part of a match are skipped.
        """
        match_map = {}  # node name -> (root_node, match_value?)
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        matched = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break
        return match_map

    def _find_quants(self, graph, matches):
        """Return node name -> ``(DefaultQuant, qconfig, is_weight)``.

        Entries are recorded for the inputs (args and kwargs) of the first op
        of every match and for the output of its last op. In dynamic
        quantization only weight inputs are recorded.
        """
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to measure quantization information
                # even for nodes where we might not use it because it is already
                # quantized. This is because each match has the option to
                # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                            is_weight = True
                if self.quant_type != QuantType.DYNAMIC or is_weight:
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    map_arg(matched[0], visit(None, qconfig))
        return quants
class Quantizer:
    """Graph-mode quantization driver for ``GraphModule`` models.

    ``prepare`` / ``prepare_dynamic`` insert ``activation_post_process_<i>``
    observer modules and stash the resulting state on the returned module
    (``save_state``); ``convert`` restores that state, replaces observer calls
    with ``torch.quantize_per_tensor`` / ``torch.quantize_per_channel`` and,
    unless ``debug`` is set, folds weight-prepacking subgraphs into attributes.
    """

    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map = None
        # mapping from node name to qconfig that should be used for that node
        # filled out for a model during _generate_qconfig_map
        self.qconfig_map = None
        # mapping from fully qualified module name to module instance
        # for example,
        # {
        #   '': Model(...),
        #   'linear': Linear(...),
        #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
        # }
        self.modules = None
        # mapping from a tuple of nodes in reverse order to uninitialized
        #   QuantizeHandler subclass. For example,
        # {
        #   # match a single node
        #   (<class 'torch.nn.modules.conv.Conv3d'>:
        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
        #   # match multiple nodes in reverse order
        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
        #     <class 'torch.quantization.fx.quantize.Add'>),
        # }
        self.patterns = None

    def _qat_swap_modules(self, root):
        """Swap modules of ``root`` in place using ``DEFAULT_QAT_MODULE_MAPPING``.

        qconfigs are kept on the swapped modules (``remove_qconfig=False``).
        """
        convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)

    def _generate_qconfig_map(self, root, input_graph):
        """Fill ``self.qconfig_map`` (node name -> qconfig or None).

        - ``get_param`` nodes use the qconfig of the module owning the param,
        - ``call_function`` nodes use the qconfig of ``root``,
        - ``call_method`` nodes reuse the qconfig of their ``self`` argument,
        - ``call_module`` nodes use the qconfig of the called module.

        Requires ``self.modules`` to be set.
        """
        def get_qconfig(module):
            return module.qconfig if hasattr(module, 'qconfig') else None

        self.qconfig_map = {}
        for node in input_graph.nodes:
            if node.op == 'get_param':
                parent, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
            elif node.op == 'call_function':
                self.qconfig_map[node.name] = get_qconfig(root)
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self` object for the call
                self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
            elif node.op == 'call_module':
                self.qconfig_map[node.name] = get_qconfig(
                    self.modules[node.target])

    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        """Return a deep copy of ``model`` with observers inserted.

        Args:
            model: the GraphModule to prepare.
            qconfig_dict: forwarded to ``propagate_qconfig_``.
            inplace: must be False (asserted).
            is_dynamic_quant: when True, dynamic-quant patterns are used and no
                output observers are inserted for matched nodes.
        """
        assert not inplace, 'inplace prepare is not supported yet'
        input_root = model
        if not inplace:
            input_root = copy.deepcopy(input_root)
        input_graph = model.graph
        self.is_dynamic_quant = is_dynamic_quant
        # TODO: allow user specified patterns
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize an DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = {}

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # The helpers below do not depend on the loop variable, so they are
        # created once here rather than for every node.
        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        def insert_observer(node, observer, device):
            # register the observer on the root and route `node` through it
            observer_name = get_new_observer_name(input_root)
            setattr(input_root, observer_name, observer)
            self.activation_post_process_map[node.name] = observer
            env[node.name] = observed_graph.create_node(
                'call_module', observer_name, (load_arg(node), ), {})
            observed.add(node.name)
            if device:
                getattr(input_root, observer_name).to(device)

        def is_observed(input_arg):
            if isinstance(input_arg, Node):
                return input_arg.name in observed
            elif isinstance(input_arg, list):
                return all(map(is_observed, input_arg))

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif isinstance(obj, (Add, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(input_root)
                    insert_observer(node, new_observer, device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(input_root)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(input_root, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        observed = GraphModule(input_root, observed_graph)
        self.save_state(observed)
        return observed

    def save_state(self, observed):
        """Stash the state ``convert`` needs as attributes on ``observed``."""
        observed._activation_post_process_map = self.activation_post_process_map
        observed._patterns = self.patterns
        observed._qconfig_map = self.qconfig_map

    def restore_state(self, observed):
        """Load the state written by ``save_state``; asserts it is present."""
        err_msg = 'please make sure the model is produced by prepare'
        assert hasattr(observed, '_activation_post_process_map'), 'did not found ' + \
            '_activation_post_process attribute ' + err_msg
        assert hasattr(observed, '_patterns'), 'did not found ' + \
            '_patterns attribute ' + err_msg
        assert hasattr(observed, '_qconfig_map'), 'did not found ' + \
            '_qconfig_map attribute ' + err_msg
        self.activation_post_process_map = observed._activation_post_process_map
        self.patterns = observed._patterns
        self.qconfig_map = observed._qconfig_map

    def prepare(self, model, qconfig_dict, inplace=False):
        """Insert observers for static quantization."""
        return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=False)

    def prepare_dynamic(self, model, qconfig_dict, inplace=False):
        """Insert observers for dynamic quantization."""
        return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True)

    def _run_weight_observers(self, observed):
        r''' Extract the subgraph that produces the weight for dynamically quantized
        node and run the subgraph to observe the weight.
        Note that the observers of dynamically quantized modules are run during
        the conversion step.
        '''
        for node in observed.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                for i, node_arg in enumerate(node.args):
                    if i in WEIGHT_INDEX_DICT[node.target]:
                        # node_arg is weight
                        weight_observer_nodes = collect_producer_nodes(node_arg)
                        if weight_observer_nodes is not None:
                            weight_observer_module = graph_module_from_producer_nodes(
                                observed, weight_observer_nodes)
                            # run the weight observer
                            weight_observer_module()

    @staticmethod
    def _all_or_none_quantized(flags):
        """Collapse per-element "is quantized" flags of a list argument.

        Returns True when every flag is truthy (also for an empty list), False
        when none is, and raises when the flags are mixed.
        """
        # Materialize first: with a lazy ``map`` the ``all`` call would consume
        # part of the iterator, so ``any`` would only see the remainder and a
        # mixed list such as [True, False] would be misreported as False.
        flags = list(flags)
        if all(flags):
            return True
        if not any(flags):
            return False
        raise Exception("partially quantized inputs in list not handled yet")

    def _convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
        """Build the quantized GraphModule from a prepared model (no weight folding).

        Matched patterns are converted by their handler's ``convert``; observer
        calls are replaced by per-tensor or per-channel quantize calls based on
        the observer's ``qscheme``; ``activation_post_process_*`` submodules
        are deleted from the (copied) root.
        """
        assert not inplace, 'inplace convert is not supported yet'
        self.restore_state(observed)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(observed)

        # move to cpu since we only have quantized cpu kernels
        observed.eval().cpu()
        observed_root = observed
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)
        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules, self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        # env holds float values, quant_env holds quantized values
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either of the environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                return self._all_or_none_quantized(map(is_quantized, node))

        def is_per_channel(qscheme):
            return qscheme == torch.per_channel_affine or \
                qscheme == torch.per_channel_symmetric

        def noattr(module, qparams, i):
            # True when none of the '<key><i>' attribute names is taken
            return not any(hasattr(module, name + str(i)) for name in qparams)

        def get_next_i(module, qparams):
            # smallest suffix i for which all qparam attribute names are free
            i = 0
            while not noattr(module, qparams, i):
                i += 1
            return i

        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                # non-root nodes of a match are produced by the root's convert
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith(
                        'activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    dtype = observer_module.dtype

                    if is_per_channel(observer_module.qscheme):
                        ch_axis = int(observer_module.ch_axis)
                        # dict order is the positional order of the quantize
                        # op's arguments after the input
                        qparams = {
                            '_scale_': scale,
                            '_zero_point_': zero_point,
                            '_axis': ch_axis,
                            '_dtype_': dtype
                        }
                        quantize_op = torch.quantize_per_channel
                    else:
                        scale = float(scale)
                        zero_point = int(zero_point)
                        qparams = {
                            '_scale_': scale,
                            '_zero_point_': zero_point,
                            '_dtype_': dtype
                        }
                        quantize_op = torch.quantize_per_tensor

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(
                            self.quantized_graph.create_node(
                                'get_param', qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node(
                        'call_function', quantize_op, tuple(inputs), {})
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        # collect first: the module tree must not change while it is iterated
        to_be_removed = [
            name for name, _ in observed_root.named_modules()
            if name.split('.')[-1].startswith('activation_post_process_')
        ]
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)

    # Trace back from the weight node until we hit getattr, reconstruct the graph module
    # with the traced nodes and run the graph module to pack the weight. then replace
    # the original chain of ops with the packed weight.
    def _fold_weight(self, quantized):
        packed_weights = {}
        # map from folded node name to the prepack node it is folded into
        folded_nodes = {}
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        get_new_packed_weight_name = get_new_attr_name_with_prefix(
            '_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_param', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        folded_graph.output(load_arg(quantized_graph.result))
        return GraphModule(quantized_root, folded_graph)

    def convert(self, observed, inplace=False, debug=False, is_dynamic=False):
        """Quantize a prepared model; weights are folded unless ``debug``."""
        quantized = self._convert(observed, inplace, debug, is_dynamic)
        if not debug:
            quantized = self._fold_weight(quantized)
        return quantized

    def _find_matches(self, graph, modules, patterns):
        """Match ``patterns`` against ``graph``, scanning nodes in reverse.

        Returns a dict mapping every matched node name to
        ``(root_node, matched_nodes, handler_instance, qconfig)``. For each
        unmatched node the first pattern (in ``patterns`` order) that matches
        wins; nodes already part of a match are skipped.
        """
        match_map = {}  # node name -> (root_node, match_value?)
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        matched = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break
        return match_map

    def _find_quants(self, graph, matches):
        """Return node name -> ``(DefaultQuant, qconfig, is_weight)``.

        Entries are recorded for the inputs (args and kwargs) of the first op
        of every match and for the output of its last op. In dynamic
        quantization only weight inputs are recorded.
        """
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to measure quantization information
                # even for nodes where we might not use it because it is already
                # quantized. This is because each match has the option to
                # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                            is_weight = True
                if (not self.is_dynamic_quant) or is_weight:
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    map_arg(matched[0], visit(None, qconfig))
        return quants
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    The cast is a ``torch.dequantize`` call when node_a expects fp32 and
    node_c takes int8 or fp16; when both nodes expect the same type (fp32,
    int8 or fp32_or_int8) it is a new ``torch.nn.Identity`` submodule
    registered on ``gm_b`` under the generated node name.

    Returns:
        The new cast node, or, when ``prev_node_c`` is a list, a list holding
        one cast node per element (in order).

    Raises:
        AssertionError: if no cast is implemented for the pair of input
            types, or if ``prev_node_c`` is neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(node_a, gm_a, logger_cls)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(node_c, gm_b, logger_cls)

    if (
        node_input_type_a == NodeInputOrOutputType.FP32 and
        node_input_type_c in (NodeInputOrOutputType.INT8, NodeInputOrOutputType.FP16)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a in (
            NodeInputOrOutputType.FP32,
            NodeInputOrOutputType.INT8,
            NodeInputOrOutputType.FP32_OR_INT8,
        )
    ):
        # types already agree: a no-op module keeps the graph shape uniform
        dtype_cast_mod_cls = torch.nn.Identity
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    def _create_cast_node(prev: Node) -> Node:
        # Each cast gets its own name; for module casts the same name is used
        # for the attribute registered on gm_b and for the call_module target.
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            return graph_c.create_node(
                'call_function', dtype_cast_op, (prev, ), {},
                new_dtype_cast_name)
        assert dtype_cast_mod_cls
        dtype_cast_mod = dtype_cast_mod_cls()
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
        return graph_c.create_node(
            'call_module', new_dtype_cast_name, (prev, ), {},
            new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _create_cast_node(prev_node_c)
    elif isinstance(prev_node_c, list):
        # cast every element on its own (the module-cast path used to pass
        # the whole list as the argument of each cast node)
        return [
            _create_cast_node(prev_node_c_inner)
            for prev_node_c_inner in prev_node_c
        ]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
def _prepare(self, model, qconfig_dict, inplace, quant_type):
    """Build an observed copy of ``model`` by inserting observer modules.

    Walks ``model.graph`` node by node and copies each node into a new graph.
    It inserts ``call_module`` observer nodes in two places:

    * after matched pattern roots whose qconfig is set (static quantization
      only);
    * after any node reported by ``self._find_quants`` that is not already
      observed.

    Each observer is set on the root module under the first free attribute
    name of the form ``activation_post_process_<i>``. It is also recorded in
    ``self.activation_post_process_map``, keyed by the observed node's name.

    Args:
        model: object exposing ``.root`` (the module) and ``.graph``;
            presumably a GraphModule -- confirm against callers.
        qconfig_dict: passed to ``propagate_qconfig_`` on the root module.
        inplace: when False, ``model.root`` is deep-copied first so the
            caller's module is not mutated.
        quant_type: stored on ``self.quant_type``. ``QuantType.DYNAMIC``
            selects the dynamic patterns and skips all observer insertion
            for matched pattern roots.

    Returns:
        A new ``GraphModule`` built from the (possibly copied) root module
        and the observed graph.

    Side effects:
        Sets ``self.quant_type``, ``self.patterns``, ``self.modules`` and
        ``self.activation_post_process_map``; mutates the root module
        (qconfig propagation, QAT swaps when training, observer attributes).
    """
    input_root = model.root
    if not inplace:
        input_root = copy.deepcopy(input_root)
    input_graph = model.graph
    self.quant_type = quant_type
    # TODO: allow user specified patterns
    if self.quant_type == QuantType.DYNAMIC:
        self.patterns = get_dynamic_quant_patterns()
    else:
        self.patterns = get_quant_patterns()

    propagate_qconfig_(input_root, qconfig_dict)
    if input_root.training:
        self._qat_swap_modules(input_root)

    self.modules = dict(input_root.named_modules())

    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(input_root, input_graph)

    # match the patterns that will get quantized
    matches = self._find_matches(input_graph, self.modules, self.patterns)

    # find _inputs_ to matched nodes that are not quantized, these
    # have to be quantized, which requires measuring stats,
    # initialize an DefaultQuant object for each
    quants = self._find_quants(input_graph, matches)

    self.activation_post_process_map = {}

    env = {}
    observed_graph = Graph()
    observed = set()

    # The helpers below do not depend on the loop variable, so they are
    # defined once here rather than re-created for every node.

    def load_arg(a):
        # Map Node references in ``a`` to their counterparts in observed_graph.
        return map_arg(a, lambda node: env[node.name])

    def get_new_observer_name(parent_module):
        # First 'activation_post_process_<i>' not already an attribute.
        i = 0
        while hasattr(parent_module, 'activation_post_process_' + str(i)):
            i += 1
        return 'activation_post_process_' + str(i)

    def insert_observer(node, observer):
        # Register ``observer`` on the root module and route ``node``'s
        # output through it; later users of ``node`` read the observer node.
        observer_name = get_new_observer_name(input_root)
        setattr(input_root, observer_name, observer)
        self.activation_post_process_map[node.name] = observer
        env[node.name] = observed_graph.create_node(
            'call_module', observer_name, [load_arg(node)], {})
        observed.add(node.name)

    def is_observed(input_arg):
        # True for an observed Node, or a list whose elements all are.
        if isinstance(input_arg, Node):
            return input_arg.name in observed
        elif isinstance(input_arg, list):
            return all(map(is_observed, input_arg))
        return False

    for node in input_graph.nodes:
        if node.name in observed:
            continue

        root_node, _, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        # Every node (unmatched, pattern root, or pattern interior) is copied.
        env[node.name] = observed_graph.node_copy(node, load_arg)

        if root_node is node:
            # don't need to insert observer for output in dynamic quantization
            # (this also skips the quants check below for pattern roots)
            if self.quant_type == QuantType.DYNAMIC:
                continue

            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed.add(node.name)
            elif isinstance(obj, (Add, Mul)) and not obj.all_nodes:
                # NOTE(review): assumes node.args[0] is a Node; if the first
                # operand can be a Python scalar (e.g. ``1 + x``) this raises
                # AttributeError -- confirm whether that pattern reaches here.
                if node.args[0].name in observed:
                    observed.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                insert_observer(node, qconfig.activation())

        if node.name not in observed and node.name in quants:
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                insert_observer(
                    node,
                    qconfig.weight() if is_weight else qconfig.activation())

    observed_graph.output(load_arg(input_graph.result))

    return GraphModule(input_root, observed_graph)
class PrimContext(torch.overrides.TorchFunctionMode):
    """
    The prototype prim tracing context.

    Example usage:

    import torch._prims.utils as utils
    from torch._prims.context import PrimContext
    from torch._prims.executor import execute
    from torch.overrides import push_torch_function_mode

    a = torch.randn((2, 2))
    b = torch.randn((2, 2))

    with push_torch_function_mode(PrimContext) as ctx:
        meta_a = ctx.placeholder(utils.TensorMeta(a))
        meta_b = ctx.placeholder(utils.TensorMeta(b))
        result = torch.add(meta_a, meta_b)
        ctx.output(result)

    exc_result = execute(ctx, a, b)

    Currently this only acquires a trace of prims, and
    it does not account for control flow. As such,
    execute must be called with tensors that have the
    same metadata (dtype, device, shape...) as
    the tensors used to trace the operations.

    The tracing context's FX graph can be acquired
    using its graph attribute.
    """

    def __init__(self):
        self.graph = Graph()

        # Private attributes for generating names
        self._tensor_name_counter = 0
        self._dim_name_counter = 0
        self._shape_name_counter = 0
        self._lowercase = tuple(string.ascii_lowercase)
        self._uppercase = tuple(string.ascii_uppercase)

    @staticmethod
    def _create_name(idx, chars):
        """Return the generated name for counter value ``idx``.

        With ``n = len(chars)``, an index ``idx >= n`` maps to
        ``chars[idx % n]`` repeated ``idx // n + 1`` times. For lowercase
        letters this yields 'a'..'z', then 'aa', 'bb', ..., 'zz', then
        'aaa', ...; every non-negative index gets a distinct name.
        Indices below ``n`` index ``chars`` directly.
        """
        if idx < len(chars):
            return chars[idx]
        repeats, remainder = divmod(idx, len(chars))
        return chars[remainder] * (repeats + 1)

    def _tensor_name(self):
        """Return the next tensor name and advance the tensor name counter."""
        idx = self._tensor_name_counter
        self._tensor_name_counter = idx + 1
        return self._create_name(idx, self._lowercase)

    def _add_user(self, tm: TensorMeta, node: Node) -> None:
        """Record ``node`` as a user of the graph node that produced ``tm``."""
        assert tm.node is not None
        tm.node.users[node] = None

    def placeholder(self, a: Any):
        """Add a placeholder node to the trace and return ``a``.

        If ``a`` is a TensorMeta, it is bound to the new node (its ``tname``
        and ``node`` attributes are set).

        Raises:
            ValueError: if ``a`` is a TensorMeta already bound to a node.
        """
        name = self._tensor_name()
        node = self.graph.placeholder(name)

        if isinstance(a, TensorMeta):
            if a.node is not None:
                raise ValueError(
                    "Attempting to reuse a TensorMeta in a new trace!")
            a.tname = name
            a.node = node

        return a

    def output(self, tm: TensorMeta):
        """Mark ``tm`` as the graph's output and register the output node as its user."""
        # TODO: allow other output types
        assert isinstance(tm, TensorMeta)

        node = self.graph.output(tm)
        self._add_user(tm, node)

    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        """
        Determines which function to call.

        The order of which function is called is determined by:

        - func's "meta" attribute, if it exists
        - if func is a torch operation, its corresponding reference
        - func
        """

        if kwargs is None:
            kwargs = {}

        if hasattr(func, "meta"):
            # TODO: add check that all args/kwargs are 'registered' properly
            # to this trace

            output = func.meta(*args, **kwargs)  # type: ignore[attr-defined]

            # Updates graph
            # TODO: handle outputs with multiple tensors
            # TODO: handle non-tensor outputs
            assert isinstance(output, TensorMeta)
            output_name = self._tensor_name()
            node = self.graph.create_node("call_function",
                                          func,
                                          name=output_name,
                                          args=args,
                                          kwargs=kwargs)
            output.tname = output_name
            output.node = node

            # Marks uses: every TensorMeta input becomes a producer of `node`.
            for arg in chain(args, kwargs.values()):
                if isinstance(arg, TensorMeta):
                    self._add_user(arg, node)

            return output

        # Remaps torch operations to their references
        if func in _torch_to_reference_map:
            fn = _torch_to_reference_map[func]
            # Re-enable this mode, replacing self.inner, while the reference
            # implementation runs.
            with torch.overrides.enable_torch_function_mode(
                    self, replace=self.inner):
                return fn(*args, **kwargs)  # type: ignore[operator]

        return func(*args, **kwargs)