def _propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix=''):
    r"""Attach a ``qconfig`` to ``module`` and, recursively, to all of its children.

    Helper for ``propagate_qconfig_``. For each module the qconfig is chosen
    with the following precedence (highest first):

    1. a ``qconfig`` attribute already present on the module,
    2. ``qconfig_dict[prefix]`` (lookup by the module's qualified name),
    3. ``qconfig_dict[type(module)]`` (lookup by the module's class),
    4. ``qconfig_parent`` (inherited from the parent module).

    Args:
        module: input module
        qconfig_dict: dictionary that maps submodule names (or module types)
            to quantization configurations
        qconfig_parent: quantization config of the parent module; used as the
            fallback when nothing more specific applies to ``module``
        prefix: qualified name of ``module``, used as a key into
            ``qconfig_dict``; a child is named ``prefix + '.' + child_name``,
            or just ``child_name`` when ``prefix`` is empty

    Return:
        None, module is modified inplace with qconfig attached
    """
    # NOTE(review): a second ``def _propagate_qconfig_helper`` appears later in
    # this file and rebinds this name, so once the module is loaded this
    # definition (including the recursive call below, which looks the name up
    # at call time) is shadowed by that later version.
    by_type = qconfig_dict.get(type(module), qconfig_parent)
    by_name = qconfig_dict.get(prefix, by_type)
    # An explicitly set ``module.qconfig`` wins over anything in qconfig_dict.
    resolved = getattr(module, 'qconfig', by_name)

    torch.ao.quantization.qconfig.assert_valid_qconfig(resolved, module)

    # The value returned by add_module_to_qconfig_obs_ctr is what gets stored
    # on the module and what the children inherit as their fallback.
    attached = add_module_to_qconfig_obs_ctr(resolved, module)
    module.qconfig = attached

    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        _propagate_qconfig_helper(child, qconfig_dict, attached, child_prefix)
def _propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix='',
                              prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Attaches a ``qconfig`` attribute to ``module`` and recursively to its
    children. For each module the qconfig is chosen with this precedence
    (highest first):

    1. a ``qconfig`` attribute already present on the module,
    2. ``qconfig_dict[prefix]`` (lookup by the module's qualified name),
    3. ``qconfig_dict[type_before_parametrizations(module)]`` (lookup by class),
    4. ``qconfig_parent`` (inherited from the parent module).

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule (or module
            type) to quantization configuration
        qconfig_parent: quantization config of parent module, we will fallback
            to this config when there is no specified config for current module
        prefix: corresponding prefix (qualified name) of the current module,
            used as key in qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`. Children
            whose fully-qualified name is listed under
            ``"non_traceable_module_name"``, or whose type is listed under
            ``"non_traceable_module_class"``, are not descended into, at any
            depth of the hierarchy.

    Return:
        None, module is modified inplace with qconfig attached
    """
    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    # An explicitly set ``module.qconfig`` wins over anything in qconfig_dict.
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    # These lookups do not depend on the child, so do them once.
    if prepare_custom_config_dict is None:
        non_traceable_names = ()
        non_traceable_classes = ()
    else:
        non_traceable_names = prepare_custom_config_dict.get(
            "non_traceable_module_name", [])
        non_traceable_classes = prepare_custom_config_dict.get(
            "non_traceable_module_class", [])

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # Do not propagate qconfig to a child if it is non traceable.
        # Non-traceable module names are fully-qualified, so match on
        # ``module_prefix`` (identical to ``name`` for direct children of the
        # root call, where ``prefix`` is '').
        if module_prefix in non_traceable_names or type(child) in non_traceable_classes:
            continue
        # Pass the custom config down so nested non-traceable modules are
        # also skipped (previously only direct children were checked).
        _propagate_qconfig_helper(child, qconfig_dict, qconfig_with_device_check,
                                  module_prefix, prepare_custom_config_dict)
def generate_qconfig_map(
        root: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        input_graph: Graph,
        qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
    """Resolve the qconfig to use for every node of ``input_graph``.

    Args:
        root: root module; not read by this function's body (kept for
            interface compatibility)
        modules: mapping from module name to module; indexed for
            ``get_attr`` and ``call_module`` nodes
        input_graph: graph whose ``nodes`` are visited in order
        qconfig_dict: user qconfig configuration; its ``""`` entry is used as
            the global qconfig
        node_name_to_scope: maps a node name to the ``(module_path,
            module_type)`` scope recorded for that node

    Returns:
        Dict from node name to the qconfig returned by
        ``add_module_to_qconfig_obs_ctr``. Nodes whose op is not one of
        get_attr / call_function / call_method / call_module map to ``None``.
        ``call_module`` nodes whose module is an activation post process
        (observer) are omitted from the result.

    Side effect:
        For ``call_module`` nodes, the resolved qconfig is also written to
        ``modules[node.target].qconfig``.
    """
    global_qconfig = qconfig_dict.get("", None)
    qconfig_map: Dict[str, QConfigAny] = {}

    # Per-scope invocation counters, for example
    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
    # means that inside submodule 'foo.bar' we have so far seen 0 F.linear
    # and 1 F.conv2d invocations.
    submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = \
        defaultdict(lambda: defaultdict(int))

    def take_index(scope, object_type):
        # Return the current invocation index of object_type in scope, then bump it.
        counts = submodule_to_object_type_to_cur_idx[scope]
        index = counts[object_type]
        counts[object_type] = index + 1
        return index

    for node in input_graph.nodes:
        op = node.op

        # Observer modules get no entry at all in the qconfig map.
        if op == "call_module" and is_activation_post_process(modules[node.target]):
            continue

        if op == "get_attr":
            # Use the config of the module that owns the attribute.
            owner_name, _ = _parent_name(node.target)
            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, type(modules[owner_name]), owner_name, global_qconfig)
        elif op == "call_function":
            # precedence: module_name_qconfig > function_qconfig > global_qconfig
            function_qconfig = get_object_type_qconfig(
                qconfig_dict, node.target, global_qconfig)
            scope_path, scope_type = node_name_to_scope[node.name]
            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, scope_type, scope_path, function_qconfig)
            index = take_index(scope_path, node.target)
            qconfig = maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_dict, scope_path, node.target, index, qconfig)
        elif op == "call_method":
            scope_path, scope_type = node_name_to_scope[node.name]
            # Look up by the method name first (node.target is a string), to
            # support configs like "object_type": [("reshape", qconfig)]; if
            # the method has no special config, fall back to the config of
            # the module that contains the call_method node.
            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, node.target, scope_path, global_qconfig)
            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, scope_type, scope_path, qconfig)
            # Order-based overrides are not supported for call_method yet.
        elif op == "call_module":
            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, type(modules[node.target]), node.target, global_qconfig)
            scope_path, scope_type = node_name_to_scope[node.name]
            # For call_module the scope path is the called module's own name,
            # so invocations are counted in its parent module.
            parent_name, _ = _parent_name(scope_path)
            index = take_index(parent_name, scope_type)
            qconfig = maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_dict, parent_name, scope_type, index, qconfig)
        else:
            # Any other node op gets no qconfig.
            qconfig_map[node.name] = None
            continue

        qconfig_with_device_check = add_module_to_qconfig_obs_ctr(
            qconfig, modules.get(node.target, None))
        if op == "call_module":
            # Regex is not supported by eager mode propagate_qconfig_, so set
            # the qconfig explicitly here in case regex is used.
            modules[node.target].qconfig = qconfig_with_device_check
        qconfig_map[node.name] = qconfig_with_device_check

    return qconfig_map