def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    """Look up the fuser method registered for ``op_list``.

    Args:
        op_list: key looked up in the combined fuser-method mapping.
        additional_fuser_method_mapping: optional extra mapping that is
            combined with ``DEFAULT_OP_LIST_TO_FUSER_METHOD`` through
            ``get_combined_dict``; ``None`` is treated as an empty dict.

    Returns:
        The entry stored under ``op_list`` in the combined mapping.

    Raises:
        AssertionError: if the combined mapping has no entry (or a ``None``
            entry) for ``op_list``. Note that the check is an ``assert``
            statement, so it is skipped when Python runs with ``-O`` and
            ``None`` is returned instead.
    """
    extra_mapping = (
        dict() if additional_fuser_method_mapping is None
        else additional_fuser_method_mapping
    )
    combined = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, extra_mapping)
    method = combined.get(op_list)
    assert method is not None, (
        "did not find fuser method for: {} ".format(op_list)
    )
    return method
Beispiel #2
0
def get_dynamic_quant_module_class(
        float_module_class: Callable,
        additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
    """Get the dynamically quantized module class for a floating point module class.

    Args:
        float_module_class: key looked up in the combined mapping.
        additional_dynamic_quant_mapping: optional extra mapping that is
            combined with ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS`` through
            ``get_combined_dict``; ``None`` is treated as an empty dict.

    Returns:
        A ``copy.deepcopy`` of the entry found for ``float_module_class``.

    Raises:
        AssertionError: if the combined mapping has no entry (or a ``None``
            entry) for ``float_module_class``.
    """
    extra_mapping = (
        {} if additional_dynamic_quant_mapping is None
        else additional_dynamic_quant_mapping
    )
    combined = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, extra_mapping)
    quantized_cls = combined.get(float_module_class)
    assert quantized_cls is not None, (
        "Floating point module class {}".format(str(float_module_class))
        + " does not have a corresponding quantized module class"
    )
    # Hand back a copy so the caller never holds the mapping's own object.
    return copy.deepcopy(quantized_cls)
Beispiel #3
0
def get_static_quant_module_class(
        float_module_class: Callable,
        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
        is_reference: bool = False) -> Any:
    """Get the statically quantized module class for a floating point module class.

    Args:
        float_module_class: key looked up in the combined mapping.
        additional_static_quant_mapping: optional extra mapping that is
            combined with the selected default mapping through
            ``get_combined_dict``; ``None`` is treated as an empty dict.
        is_reference: when truthy, the defaults come from
            ``DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS``; otherwise
            from ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS``.

    Returns:
        A ``copy.deepcopy`` of the entry found for ``float_module_class``.

    Raises:
        AssertionError: if the combined mapping has no entry (or a ``None``
            entry) for ``float_module_class``.
    """
    extra_mapping = (
        {} if additional_static_quant_mapping is None
        else additional_static_quant_mapping
    )
    # Choose which default table the extra mapping is combined with.
    if is_reference:
        base_mapping = DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS
    else:
        base_mapping = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS
    combined = get_combined_dict(base_mapping, extra_mapping)
    quantized_cls = combined.get(float_module_class)
    assert quantized_cls is not None, (
        "Floating point module class {}".format(str(float_module_class))
        + " does not have a corresponding quantized module class"
    )
    # Hand back a copy so the caller never holds the mapping's own object.
    return copy.deepcopy(quantized_cls)
def get_native_quant_patterns(
    additional_quant_patterns: Dict[Pattern, QuantizerCls] = None
) -> Dict[Pattern, QuantizerCls]:
    """Build the pattern -> quantize handler map for the native backend.

    The map starts from ``get_default_quant_patterns()``, is combined with
    ``additional_quant_patterns`` (when given) through ``get_combined_dict``,
    and then receives every handler derived from
    ``get_native_backend_config_dict()``; a backend-config handler replaces
    any existing entry stored under the same pattern.

    Args:
        additional_quant_patterns: optional extra pattern -> handler entries;
            ``None`` means no extra entries.

    Returns:
        The result of ``sorted_patterns_dict`` applied to the assembled map,
        intended to yield longer patterns first when iterated.
    """
    pattern_map = get_default_quant_patterns()
    if additional_quant_patterns is not None:
        pattern_map = get_combined_dict(pattern_map, additional_quant_patterns)

    # TODO: for now the handlers generated from `get_native_backend_config_dict`
    # only extend the default patterns; once everything is defined there, the
    # backend_config_dict can be used on its own.
    backend_config = get_native_backend_config_dict()
    backend_handlers = get_pattern_to_quantize_handlers(backend_config)
    for backend_pattern, handler in backend_handlers.items():
        pattern_map[backend_pattern] = handler

    return sorted_patterns_dict(pattern_map)