def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module):
    """Insert observers into ``model`` so calibration statistics can be
    collected before conversion to a quantized model.

    A standalone module is a submodule that is not inlined into its parent
    module and is quantized separately as one unit.  When preparing a
    standalone module, the input of the module is observed in the parent
    module while the output is observed inside the standalone module itself.

    Args:
        model: a ``GraphModule`` (traced with torch.fx) to be prepared
        qconfig_dict: maps module names/types to QConfig objects
        inplace: if False, work on a deep copy of ``model``
        prepare_custom_config_dict: optional dict carrying extra quant
            patterns, QAT module mappings, and standalone/custom module
            configuration
        is_standalone_module: whether ``model`` itself is being prepared
            as a standalone module for a parent graph

    Returns:
        model(GraphModule): prepared module; when ``is_standalone_module``
        is True it additionally carries:
            _standalone_module_observed_input_idxs(List[Int]): indexes of
                the graph inputs that need to be observed in the parent
            _output_is_observed(Bool): whether the output of the module
                is observed or not
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if not inplace:
        model = copy.deepcopy(model)
    # merge user-supplied quantization patterns on top of the defaults
    additional_quant_patterns = prepare_custom_config_dict.get(
        "additional_quant_pattern", {})
    self.patterns = get_default_quant_patterns().copy()
    self.patterns.update(additional_quant_patterns)

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        # BUGFIX: key was previously misspelled "additioanl_qat_module_mapping",
        # so user-provided QAT module mappings were silently ignored
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        self._qat_swap_modules(model, additional_qat_module_mapping)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    custom_module_class_mapping = prepare_custom_config_dict.get(
        "float_to_observed_custom_module_class", None)
    matches = self._find_matches(
        model.graph, self.modules, self.patterns, standalone_module_names,
        custom_module_class_mapping)

    # find _inputs_ to matched nodes that are not quantized; these
    # have to be quantized, which requires measuring stats --
    # initialize a DefaultQuant object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    # maps original node name -> corresponding node in observed_graph
    env = {}
    observed_graph = Graph()
    observed_node_names_set = set()

    def load_arg(a):
        # remap (possibly nested) args from the original graph into the
        # observed graph
        return map_arg(a, lambda node: env[node.name])

    # indexes for the inputs that need to be observed
    standalone_module_observed_input_idxs = []
    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    get_new_observer_name = get_new_attr_name_with_prefix(
        'activation_post_process_')

    result_node: Optional[Node] = None
    for node in model.graph.nodes:
        if node.op == 'output':
            observed_graph.output(load_arg(node.args[0]))
            result_node = node
            continue
        if node.name in observed_node_names_set:
            continue

        prefix = node.name + '_activation_post_process_'
        root_node, _, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is None:
            # unmatched node: copy through unchanged
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            if qconfig is None:
                continue

            def insert_observer(node, observer, device):
                # attach the observer module to the model and insert a
                # call_module node for it in the observed graph
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                setattr(model, observer_name, observer)
                self.activation_post_process_map[node.name] = observer
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)
                if device:
                    getattr(model, observer_name).to(device)

            if isinstance(obj, CustomModuleQuantizeHandler):
                # swap the float custom module for its observed counterpart
                custom_module = self.modules[node.target]
                observed_custom_module_class = \
                    custom_module_class_mapping[type(custom_module)]
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_custom_module)

            # indexes for inputs of a standalone module that need to be
            # observed in the parent
            standalone_module_input_idxs = None
            if isinstance(obj, StandaloneModuleQuantizeHandler):
                # observe standalone module by preparing it recursively
                standalone_module = self.modules[node.target]
                prepare = \
                    torch.quantization.quantize_fx._prepare_standalone_module_fx
                observed_standalone_module = prepare(
                    standalone_module, {'': qconfig})
                observed_standalone_module.qconfig = qconfig
                standalone_module_input_idxs = \
                    observed_standalone_module._standalone_module_observed_input_idxs
                observed_standalone_module = mark_observed_standalone_module(
                    observed_standalone_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_standalone_module)
                self.modules[node.target] = observed_standalone_module

            # don't need to insert observer for output if activation does
            # not need to be statically quantized
            if not activation_is_statically_quantized(qconfig):
                continue

            # insert observers for output of observed module, or mark the
            # output as observed
            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module', 'call_function', 'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                def is_observed(input_arg):
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed_node_names_set
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and \
                    not obj.all_nodes:
                if node.args[0].name in observed_node_names_set:
                    observed_node_names_set.add(node.name)
            elif isinstance(obj, StandaloneModuleQuantizeHandler):
                assert node.op == 'call_module'
                output_is_observed = self.modules[
                    node.target]._output_is_observed
                if output_is_observed:
                    observed_node_names_set.add(node.name)
            elif qconfig is not None and obj.all_nodes:
                # observer for outputs
                new_observer = qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                insert_observer(node, new_observer, device)

            # insert observer for input of standalone module
            if standalone_module_input_idxs is not None:
                for idx in standalone_module_input_idxs:
                    if node.args[idx].name not in observed_node_names_set:
                        new_observer = qconfig.activation()
                        device = assert_and_get_unique_device(model)
                        insert_observer(node.args[idx], new_observer, device)
        else:
            # interior node of a matched pattern: copy through; the root
            # node handles observation for the whole pattern
            env[node.name] = observed_graph.node_copy(node, load_arg)

        if node.name not in observed_node_names_set and node.name in quants:
            if is_standalone_module and node.name in graph_inputs:
                # we'll insert observer for input of standalone module
                # in parent graph
                standalone_module_observed_input_idxs.append(
                    graph_inputs.index(node.name))
                continue
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            _, qconfig, is_weight = quants[node.name]
            if qconfig is not None:
                # TODO: use insert_observer
                new_observer = \
                    qconfig.weight() if is_weight else qconfig.activation()
                # respect device affinity when adding observers
                device = assert_and_get_unique_device(model)
                if device:
                    new_observer.to(device)
                self.activation_post_process_map[node.name] = new_observer
                setattr(model, observer_name,
                        self.activation_post_process_map[node.name])
                env[node.name] = observed_graph.create_node(
                    'call_module', observer_name, (load_arg(node),), {})
                observed_node_names_set.add(node.name)

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            'standalone module returning dict is not yet supported'
        # indicator for whether output is observed or not,
        # used to correctly quantize standalone modules
        output_is_observed = \
            result_node.args[0].name in observed_node_names_set
        model._standalone_module_observed_input_idxs = \
            standalone_module_observed_input_idxs
        model._output_is_observed = output_is_observed
    return model
def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalone_module):
    """Insert observers into ``model`` so calibration statistics can be
    collected before conversion.

    A standalone module is a submodule that is not inlined in the parent
    module, and will be quantized separately as one unit.

    When we are preparing a standalone module:
    input of the module is observed in the parent module, output of the
    module is observed in the standalone module.
    Returns:
        model(GraphModule): prepared standalone module with the following
        attributes:
            _standalone_module_observed_input_idxs(List[Int]): a list of
                indexes for the graph inputs that need to be observed in
                the parent module
            _output_is_observed(Bool): a boolean variable indicating
                whether the output of the custom module is observed or not
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    # merge user-supplied quantization patterns on top of the defaults
    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    self.patterns = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        # swap float modules for their QAT variants before observing
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        self._qat_swap_modules(model, additional_qat_module_mapping)

    self.modules = dict(model.named_modules())

    convert_dict_to_ordered_dict(qconfig_dict)
    # map from node name to qconfig, used in _find_matches
    self._generate_qconfig_map(model, model.graph, qconfig_dict)

    # match the patterns that will get quantized
    standalone_module_names = prepare_custom_config_dict.get(
        "standalone_module_name", None)
    standalone_module_classes = prepare_custom_config_dict.get(
        "standalone_module_class", None)
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    matches = self._find_matches(
        model.graph, self.modules, self.patterns, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    # find _inputs_ to matched nodes that are not quantized; these
    # have to be quantized, which requires measuring stats --
    # initialize a DefaultQuantizeHandler object for each
    quants = self._find_quants(model.graph, matches)

    self.activation_post_process_map = dict()
    # maps original node name -> corresponding node in observed_graph
    env: Dict[Any, Any] = {}
    observed_graph = Graph()
    observed_node_names_set = set()

    def load_arg(a):
        # remap (possibly nested) args from the original graph into the
        # observed graph
        return map_arg(a, lambda node: env[node.name])

    # indexes for the inputs that need to be observed
    standalone_module_observed_input_idxs = []
    graph_inputs = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    # NOTE(review): this binding is shadowed inside insert_observer below
    # and appears unused at this scope
    get_new_observer_name = get_new_attr_name_with_prefix(
        'activation_post_process_')

    model_device = assert_and_get_unique_device(model)

    def insert_observer(node, observer):
        """Insert an observer for ``node`` by modifying observed_graph and
        attaching the observer module to the model.

        Args:
            node: Node
            observer: observer/fake_quantize module instance
        """
        # respect device affinity when adding observers
        if model_device:
            observer.to(model_device)
        # add observer module as attribute of the model
        prefix = node.name + '_activation_post_process_'
        get_new_observer_name = get_new_attr_name_with_prefix(prefix)
        observer_name = get_new_observer_name(model)
        setattr(model, observer_name, observer)
        # put observer instance in the activation_post_process map
        assert self.activation_post_process_map is not None
        self.activation_post_process_map[node.name] = observer
        # insert observer call into the observed graph
        env[node.name] = observed_graph.create_node(
            'call_module', observer_name, (load_arg(node),), {})
        observed_node_names_set.add(node.name)

    def insert_observer_for_special_module(quantize_handler):
        """Insert observers for custom modules and standalone modules.

        Returns:
            standalone_module_input_idxs: the indexes for inputs that need
            to be observed by the parent module (None unless the matched
            node is a standalone module)
        """
        standalone_module_input_idxs = None
        assert self.modules is not None
        if isinstance(quantize_handler, CustomModuleQuantizeHandler):
            # swap the float custom module for its observed counterpart
            custom_module = self.modules[node.target]
            custom_module_class_mapping = prepare_custom_config_dict.get(
                "float_to_observed_custom_module_class", {})
            observed_custom_module_class = \
                get_swapped_custom_module_class(
                    custom_module, custom_module_class_mapping, qconfig)
            observed_custom_module = \
                observed_custom_module_class.from_float(custom_module)
            parent_name, name = _parent_name(node.target)
            setattr(self.modules[parent_name], name, observed_custom_module)
        elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
            # observe standalone module by preparing it recursively
            standalone_module = self.modules[node.target]
            prepare = \
                torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
            observed_standalone_module = \
                prepare(standalone_module, {"": qconfig})
            observed_standalone_module.qconfig = qconfig
            standalone_module_input_idxs = observed_standalone_module.\
                _standalone_module_observed_input_idxs
            observed_standalone_module = mark_observed_standalone_module(
                observed_standalone_module)
            parent_name, name = _parent_name(node.target)
            setattr(self.modules[parent_name], name,
                    observed_standalone_module)
            self.modules[node.target] = observed_standalone_module
        return standalone_module_input_idxs

    def insert_observer_for_output_of_the_node(
            node,
            quantize_handler,
            qconfig,
            standalone_module_input_idxs):
        """Insert an observer/fake_quantize module for the output of the
        observed module, if needed.
        """
        # don't need to insert observer for output if activation does not
        # need to be statically quantized
        assert self.modules is not None
        if activation_is_statically_quantized(qconfig):
            if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
                    and model.training:
                # we only insert fake quantize module in qat
                assert pattern is not None
                activation_post_process_ctr = \
                    get_default_output_activation_post_process_map().get(
                        pattern, None)
                assert activation_post_process_ctr is not None, \
                    "activation_post_process constructor not provided " + \
                    "for pattern:" + str(pattern)
                insert_observer(node, activation_post_process_ctr())
            elif (isinstance(quantize_handler,
                             FixedQParamsOpQuantizeHandler) and
                    not model.training) or \
                    isinstance(quantize_handler, CopyNode):
                # insert observers for output of observed module, or
                # mark the output as observed
                assert node.op in [
                    'call_module', 'call_function', 'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'

                def is_observed(input_arg):
                    if isinstance(input_arg, Node):
                        return input_arg.name in observed_node_names_set
                    elif isinstance(input_arg, list):
                        return all(map(is_observed, input_arg))
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif ((isinstance(quantize_handler, Add) or
                    isinstance(quantize_handler, Mul)) and
                    quantize_handler.num_node_args == 1):
                assert matched_nodes is not None
                input_node = matched_nodes[-1]  # first node in the sequence

                def input_is_observed(arg):
                    return (isinstance(arg, Node) and
                            arg.name in observed_node_names_set)
                # This checks if one of the arguments of add/mul
                # is an observed node.
                # If both of the inputs are numbers,
                # we will not consider the output to be observed
                if (input_is_observed(input_node.args[0]) or
                        input_is_observed(input_node.args[1])):
                    observed_node_names_set.add(node.name)
            elif isinstance(quantize_handler,
                            StandaloneModuleQuantizeHandler):
                assert node.op == 'call_module'
                output_is_observed = \
                    self.modules[node.target]._output_is_observed
                if output_is_observed:
                    observed_node_names_set.add(node.name)
            elif (quantize_handler.all_node_args and
                    input_output_observed(quantize_handler)):
                # observer for outputs
                new_observer = qconfig.activation()
                insert_observer(node, new_observer)

        # insert observer for input of standalone module
        if standalone_module_input_idxs is not None:
            for idx in standalone_module_input_idxs:
                if node.args[idx].name not in observed_node_names_set:
                    new_observer = qconfig.activation()
                    insert_observer(node.args[idx], new_observer)

    def insert_observer_for_input_arg_of_observed_node(arg):
        """
        Input:
            arg: input arg node for another observed node, e.g.
            input activation for a functional linear node
        """
        # NOTE(review): this closure reads the enclosing loop variable
        # ``node`` rather than ``arg`` throughout its body
        if node.name not in observed_node_names_set and node.name in quants:
            if is_standalone_module and node.name in graph_inputs:
                # we'll insert observer for input of standalone module
                # in parent graph
                standalone_module_observed_input_idxs.append(
                    graph_inputs.index(node.name))
                return
            _, activation_post_process_ctr = quants[node.name]
            if activation_post_process_ctr is not None:
                insert_observer(node, activation_post_process_ctr())

    result_node: Optional[Node] = None
    for node in model.graph.nodes:
        if node.op == 'output':
            observed_graph.output(load_arg(node.args[0]))
            result_node = node
            continue
        if node.name in observed_node_names_set:
            continue

        root_node, matched_nodes, pattern, obj, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        if root_node is None:
            # unmatched node: copy through unchanged
            env[node.name] = observed_graph.node_copy(node, load_arg)
        elif root_node is node:
            env[node.name] = observed_graph.node_copy(node, load_arg)
            # index for input of custom module that needs to be observed in
            # parent
            if qconfig is not None:
                standalone_module_input_idxs = \
                    insert_observer_for_special_module(obj)
                insert_observer_for_output_of_the_node(
                    node, obj, qconfig, standalone_module_input_idxs)
        else:
            # interior node of a matched pattern: copy through; the root
            # node handles observation for the whole pattern
            env[node.name] = observed_graph.node_copy(node, load_arg)
        insert_observer_for_input_arg_of_observed_node(node)

    model = GraphModule(model, observed_graph)
    self.save_state(model)
    model = mark_observed_module(model)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            'standalone module returning dict is not yet supported'
        # indicator for whether output is observed or not,
        # used to correctly quantize standalone modules
        output_is_observed = \
            result_node.args[0].name in observed_node_names_set
        model._standalone_module_observed_input_idxs = \
            standalone_module_observed_input_idxs
        model._output_is_observed = output_is_observed
    return model