def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    """Get the fuser method registered for the given list of module types.

    Args:
        op_list: key looked up directly in the combined mapping, so it must be
            hashable (presumably a tuple of module classes -- TODO confirm
            against callers).
        additional_fuser_method_mapping: optional extra ``{op_list: fuser}``
            entries, combined with ``DEFAULT_OP_LIST_TO_FUSER_METHOD`` via
            ``get_combined_dict``. ``None`` means no extra entries.

    Returns:
        The fuser method registered for ``op_list``.

    Raises:
        AssertionError: if no fuser method is registered for ``op_list``.
            The check is an ``assert``, so under ``python -O`` it is skipped
            and ``None`` is returned instead.
    """
    if additional_fuser_method_mapping is None:
        additional_fuser_method_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD,
                                     additional_fuser_method_mapping)
    fuser_method = all_mappings.get(op_list, None)
    assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
    return fuser_method
def get_dynamic_quant_module_class(
        float_module_class: Callable,
        additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
    r"""Get the dynamically quantized module class corresponding to the
    floating point module class.

    Args:
        float_module_class: floating point module class to look up.
        additional_dynamic_quant_mapping: optional extra
            ``{float_class: quantized_class}`` entries, combined with
            ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS`` via
            ``get_combined_dict``. ``None`` means no extra entries.

    Returns:
        A ``copy.deepcopy`` of the mapped dynamically quantized module class.
        For class objects the ``copy`` module returns the object unchanged.

    Raises:
        AssertionError: if ``float_module_class`` has no mapping. This is an
            ``assert``, so it is skipped under ``python -O``.
    """
    if additional_dynamic_quant_mapping is None:
        additional_dynamic_quant_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
                                     additional_dynamic_quant_mapping)
    dynamic_quant_module_class = all_mappings.get(float_module_class, None)
    assert dynamic_quant_module_class is not None, \
        "Floating point module class {}".format(str(float_module_class)) + \
        " does not have a corresponding quantized module class"
    return copy.deepcopy(dynamic_quant_module_class)
def get_static_quant_module_class(
        float_module_class: Callable,
        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
        is_reference: bool = False) -> Any:
    r"""Get the statically quantized module class corresponding to the
    floating point module class.

    Args:
        float_module_class: floating point module class to look up.
        additional_static_quant_mapping: optional extra
            ``{float_class: quantized_class}`` entries, combined with the
            default mapping via ``get_combined_dict``. ``None`` means no extra
            entries.
        is_reference: if True, the default mapping is
            ``DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS``; otherwise it is
            ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS``.

    Returns:
        A ``copy.deepcopy`` of the mapped statically quantized module class.
        For class objects the ``copy`` module returns the object unchanged.

    Raises:
        AssertionError: if ``float_module_class`` has no mapping. This is an
            ``assert``, so it is skipped under ``python -O``.
    """
    if additional_static_quant_mapping is None:
        additional_static_quant_mapping = {}
    default_mappings = (DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS
                        if is_reference else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
    all_mappings = get_combined_dict(default_mappings, additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class, None)
    assert static_quant_module_class is not None, \
        "Floating point module class {}".format(str(float_module_class)) + \
        " does not have a corresponding quantized module class"
    return copy.deepcopy(static_quant_module_class)
def get_native_quant_patterns(
        additional_quant_patterns: Dict[Pattern, QuantizerCls] = None
) -> Dict[Pattern, QuantizerCls]:
    """Return a map from pattern to quantize handler.

    The map is built from the default quant patterns, optionally combined
    with ``additional_quant_patterns``, and then extended with the handlers
    derived from the native ``backend_config_dict``. The result is passed
    through ``sorted_patterns_dict`` so that longer patterns are encountered
    first when iterating through it.

    Args:
        additional_quant_patterns: optional extra ``{pattern: handler}``
            entries, combined with the defaults via ``get_combined_dict``.
            ``None`` means no extra entries. NOTE(review): the annotation is
            implicitly Optional; consider ``Optional[...]`` if ``Optional`` is
            imported in this module.

    Returns:
        The sorted pattern -> quantize handler map.
    """
    import copy

    # Shallow-copy so the handler insertions below never mutate the object
    # returned by get_default_quant_patterns(); its ownership is not visible
    # here. copy.copy keeps the original mapping type.
    patterns = copy.copy(get_default_quant_patterns())
    if additional_quant_patterns is not None:
        patterns = get_combined_dict(patterns, additional_quant_patterns)
    # TODO: currently we just extend the quantize handlers generated from
    # `get_native_backend_config_dict`
    # in the future we can just assign backend_config_dict when everything is defined
    native_handlers = get_pattern_to_quantize_handlers(get_native_backend_config_dict())
    for pattern, quantize_handler in native_handlers.items():
        # Native backend handlers take precedence over default/additional ones.
        patterns[pattern] = quantize_handler
    return sorted_patterns_dict(patterns)