Ejemplo n.º 1
0
def create_compression_algorithm_builders(
        config: NNCFConfig,
        should_init: bool = True) -> List[CompressionAlgorithmBuilder]:
    """Create a builder for every entry of the 'compression' config section.

    The 'compression' section of ``config`` is deep-copied and may be either
    a single dict (one algorithm) or an iterable of dicts (several
    algorithms). Each entry is wrapped in its own ``NNCFConfig`` that
    receives the extra structs of ``config`` and a "hw_config_type" key.

    Args:
        config: Top-level configuration. Its 'compression' and
            "hw_config_type" entries are read; a missing 'compression'
            entry is treated as an empty dict.
        should_init: Passed through unchanged, as a keyword argument, to
            every builder that is constructed.

    Returns:
        A list with one builder per compression entry, in config order.
        A dict section always yields a one-element list.
    """
    sections = deepcopy(config.get('compression', {}))

    # The type is resolved once and shared by all per-algorithm configs;
    # it stays None when the top-level config does not specify it.
    hw_type_name = config.get("hw_config_type")
    hw_type = None
    if hw_type_name is not None:
        hw_type = HWConfigType.from_str(hw_type_name)

    def _make_builder(section):
        # Wrap one raw section and instantiate the builder selected for it.
        algo_cfg = NNCFConfig(section)
        algo_cfg.register_extra_structs(
            config.get_all_extra_structs_for_copy())
        algo_cfg["hw_config_type"] = hw_type
        builder_factory = get_compression_algorithm(algo_cfg)
        return builder_factory(algo_cfg, should_init=should_init)

    # A lone dict describes a single algorithm; treat it as a 1-item list.
    if isinstance(sections, dict):
        sections = [sections]
    return [_make_builder(section) for section in sections]
Ejemplo n.º 2
0
def create_compression_algorithm_builders(
        config: NNCFConfig,
        should_init: bool = True) -> List[CompressionAlgorithmBuilder]:
    """Create a builder for every entry of the 'compression' config section.

    The 'compression' section of ``config`` is deep-copied and may be either
    a single dict (one algorithm) or an iterable of dicts (several
    algorithms). Each entry is wrapped in its own ``NNCFConfig`` that
    receives the extra structs of ``config`` plus the "hw_config_type" and
    'quantizer_setup_type' keys computed here.

    Args:
        config: Top-level configuration. Its 'compression',
            "quantizer_setup_type" (default "propagation_based") and
            "target_device" (default "ANY") entries are read.
        should_init: Passed through unchanged, as a keyword argument, to
            every builder that is constructed.

    Returns:
        A list with one builder per compression entry, in config order.
        A dict section always yields a one-element list.

    Raises:
        KeyError: If the target device is consulted and has no entry in
            ``HW_CONFIG_TYPE_TARGET_DEVICE_MAP``.
    """
    raw_section = deepcopy(config.get('compression', {}))

    setup_type = QuantizerSetupType.from_str(
        config.get("quantizer_setup_type", "propagation_based"))

    # A hardware config type is only derived for the propagation-based
    # setup, and never for the 'TRIAL' target device; otherwise it is None.
    hw_type = None
    if setup_type == QuantizerSetupType.PROPAGATION_BASED:
        device = config.get("target_device", "ANY")
        if device != 'TRIAL':
            hw_type_name = HW_CONFIG_TYPE_TARGET_DEVICE_MAP[device]
            hw_type = HWConfigType.from_str(hw_type_name)

    # A lone dict describes a single algorithm; iterate it as one entry.
    if isinstance(raw_section, dict):
        sections = (raw_section,)
    else:
        sections = raw_section

    builders = []
    for section in sections:
        algo_cfg = NNCFConfig(section)
        algo_cfg.register_extra_structs(
            config.get_all_extra_structs_for_copy())
        algo_cfg["hw_config_type"] = hw_type
        algo_cfg['quantizer_setup_type'] = setup_type
        builder_factory = get_compression_algorithm(algo_cfg)
        builders.append(builder_factory(algo_cfg, should_init=should_init))
    return builders