def compare_prepare_convert_qconfig_dict(
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]],
        convert_qconfig_dict: Dict[str, Dict[Any, Any]]) -> None:
    r""" Compare the qconfig_dict passed in convert to the one from prepare and check the values

    Args:
      `prepare_qconfig_dict`: configuration dictionary for prepare quantization step
      `convert_qconfig_dict`: configuration dictionary for convert quantization step

    Raises:
      AssertionError: if a key present in prepare is missing from convert, or if a
        convert qconfig is neither None nor equal to the prepare qconfig
      ValueError: if prepare_qconfig_dict contains an unsupported top-level key
    """
    for k in prepare_qconfig_dict.keys():
        if k == '':
            # '' holds the global qconfig; convert may keep it identical or
            # explicitly clear it with None.
            assert k in convert_qconfig_dict, \
                "Missing key {} from convert qconfig_dict when it was present in prepare".format(k)
            assert (
                convert_qconfig_dict[k] is None
                or qconfig_equals(prepare_qconfig_dict[k], convert_qconfig_dict[k])
            ), (  # type: ignore[arg-type]
                "Expected convert qconfig_dict have the same qconfig as prepare qconfig_dict or None. "
                "Updated qconfig {} to {} for key {}".format(
                    prepare_qconfig_dict[k], convert_qconfig_dict[k], k))
        elif k in ['object_type', 'module_name', 'module_name_regex']:
            # Fixed typo: was 'module_namr_regex', which silently routed valid
            # 'module_name_regex' entries into the unsupported-key branch.
            for name, qconfig in prepare_qconfig_dict[k].items():
                assert name in convert_qconfig_dict[k], \
                    "Missing key {} {} from convert qconfig_dict when it was present in prepare".format(k, name)
                # None in convert is accepted without comparison.
                assert convert_qconfig_dict[k][name] is None \
                    or qconfig_equals(qconfig, convert_qconfig_dict[k][name]), \
                    "Expected convert qconfig_dict have the same qconfig as prepare qconfig_dict or None. " \
                    "Updated qconfig {} to {} for key {} {}".format(
                        prepare_qconfig_dict[k], convert_qconfig_dict[k], k, name)
        else:
            # Was `assert "Unsupported key ..."` — asserting a truthy string
            # literal, which can never fail; raise so unknown keys are reported.
            raise ValueError("Unsupported key in convert_qconfig_dict {}".format(k))
def compare_prepare_convert_qconfig_mappings(
        prepare_qconfig_mapping: QConfigMapping,
        convert_qconfig_mapping: QConfigMapping):
    r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values

    Args:
      `prepare_qconfig_mapping`: configuration for prepare quantization step
      `convert_qconfig_mapping`: configuration for convert quantization step

    Raises:
      AssertionError: if the global qconfigs differ, if a key present in prepare
        is missing from convert, or if a convert qconfig is neither None nor
        equal to the prepare qconfig
    """
    assert qconfig_equals(prepare_qconfig_mapping.global_qconfig,
                          convert_qconfig_mapping.global_qconfig), \
        "Expected global qconfigs to be the same in the prepare and convert quantization configs"
    # Walk the three per-category qconfig dicts of both mappings in lockstep
    # (zip of explicit triples instead of indexing parallel lists).
    dict_triples = [
        (OBJECT_TYPE_DICT_KEY,
         prepare_qconfig_mapping.object_type_qconfigs,
         convert_qconfig_mapping.object_type_qconfigs),
        (MODULE_NAME_DICT_KEY,
         prepare_qconfig_mapping.module_name_qconfigs,
         convert_qconfig_mapping.module_name_qconfigs),
        (MODULE_NAME_REGEX_DICT_KEY,
         prepare_qconfig_mapping.module_name_regex_qconfigs,
         convert_qconfig_mapping.module_name_regex_qconfigs),
    ]
    for dict_name, prepare_dict, convert_dict in dict_triples:
        for name, qconfig in prepare_dict.items():
            assert name in convert_dict, \
                "Missing key {} {} in convert QConfigMapping when it was present in prepare".format(
                    dict_name, name)
            # None in convert is accepted without comparison.
            assert convert_dict[name] is None \
                or qconfig_equals(qconfig, convert_dict[name]), \
                "Expected convert QConfigMapping to have the same qconfig as prepare for key {} {}; " \
                "prepare: {}; convert: {}".format(dict_name, name, qconfig, convert_dict[name])
def test_embedding_qat_qconfig_equal(self):
    # Embedding QAT pairs a NoopObserver class for activation with a
    # FakeQuant for weight; this verifies that qconfig_equals compares
    # correctly when a qconfig mixes a partial function and a class.
    qat_model = prepare_qat(ManualEmbeddingBagLinear().train())
    self.assertTrue(
        qconfig_equals(qat_model.emb.qconfig,
                       default_embedding_qat_qconfig))
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.
    """
    object_type_dict = qconfig_dict.get("object_type", None)
    if object_type_dict is None:
        return qconfig_dict

    named_modules = dict(model.named_modules())
    for node in model.graph.nodes:
        # Only call_module nodes that resolve to a fused module are relevant.
        if node.op != 'call_module' or node.target not in named_modules:
            continue
        candidate = named_modules[str(node.target)]
        if not isinstance(candidate, _FusedModule):
            continue
        sub_ops = list(candidate._modules.values())
        # The qconfig configured for the first sub-module's type is taken as
        # the qconfig for the whole fusion.
        fused_qconfig = object_type_dict.get(type(sub_ops[0]), None)

        # Raise an error if the modules in the fused module have
        # different qconfigs specified in the qconfig_dict
        # TODO: currently it only works for modules,
        # need to make this work for torch.nn.functional.relu
        # TODO: currently it only works for object_type configurations,
        # ideally it should work for different types of configurations,
        # maybe we want to redesign this part
        for sub_op in sub_ops[1:]:
            if not qconfig_equals(object_type_dict.get(type(sub_op), None),
                                  fused_qconfig):
                raise LookupError(
                    "During fusion, we need to specify the same " +
                    f"qconfigs for all module types in {type(candidate)} " +
                    f"offending type: {type(sub_op)}")

        if fused_qconfig is not None:
            object_type_dict[type(candidate)] = fused_qconfig
    return qconfig_dict
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.
    """
    object_type_dict = qconfig_dict.get("object_type", None)
    if object_type_dict is None:
        return qconfig_dict

    # Hoisted: the set of known fused-module types does not change in the loop.
    fused_module_types = list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values())
    named_modules = dict(model.named_modules())
    for node in model.graph.nodes:
        if node.op != 'call_module' or node.target not in named_modules:
            continue
        module_type = type(named_modules[str(node.target)])
        if module_type not in fused_module_types:
            continue
        for op_list, fuser_method in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
            if module_type != fuser_method:
                continue
            # The qconfig configured for the first op in the fusion pattern is
            # taken as the qconfig for the whole fusion.
            fused_qconfig = object_type_dict.get(op_list[0], None)

            # Raise an error if the modules in the fused module have
            # different qconfigs specified in the qconfig_dict
            for op in op_list:
                if not qconfig_equals(object_type_dict.get(op, None),
                                      fused_qconfig):
                    raise LookupError(
                        "During fusion, we need to specify the same " +
                        f"qconfigs for both modules in {module_type}.")

            if fused_qconfig is not None:
                object_type_dict[module_type] = fused_qconfig
    return qconfig_dict