def get_kd_config(config: NNCFConfig, kd_type='mse', scale=1, temperature=None) -> NNCFConfig:
    """Append a knowledge-distillation algorithm entry to the 'compression' section of ``config``.

    The 'compression' section is first normalized to a list: a missing or
    ``None`` section becomes an empty list, and a single algorithm dict is
    wrapped into a one-element list. An existing list is extended as-is.
    ``config`` is modified in place and also returned.

    :param config: NNCF configuration to extend.
    :param kd_type: value stored under the entry's 'type' key.
    :param scale: value stored under the entry's 'scale' key.
    :param temperature: stored under 'temperature' only when not None.
    :return: the same ``config`` object that was passed in.
    """
    compression = config.get('compression', None)
    if compression is None:
        # Treat an explicit None exactly like a missing section so that the
        # append below always operates on a list.
        config['compression'] = []
    elif isinstance(compression, dict):
        # A single algorithm may be given as a bare dict; promote it to a list.
        config['compression'] = [compression]

    kd_algo_dict = {
        'algorithm': 'knowledge_distillation',
        'type': kd_type,
        'scale': scale,
    }
    if temperature is not None:
        kd_algo_dict['temperature'] = temperature
    config['compression'].append(kd_algo_dict)
    return config
def __init__(self, config: NNCFConfig, should_init: bool = True):
    """Read quantization options from ``config`` and load the matching HW config.

    :param config: NNCF configuration, forwarded to the base class initializer.
    :param should_init: when truthy (as seen via ``self.should_init``),
        initialization parameters are parsed as well.
    :raises RuntimeError: if the target device is 'VPU' and the
        algorithm-specific config section contains a 'preset' key.
    """
    super().__init__(config, should_init)

    algo_params = self._algo_config
    self.quantize_inputs = algo_params.get('quantize_inputs', True)
    self.quantize_outputs = algo_params.get('quantize_outputs', False)
    self._overflow_fix = algo_params.get('overflow_fix', 'enable')
    self._target_device = config.get('target_device', 'ANY')

    # Presets are rejected for the VPU target device.
    specific_section = self._get_algo_specific_config_section()
    if self._target_device == 'VPU' and 'preset' in specific_section:
        raise RuntimeError(
            "The VPU target device does not support presets.")

    # Created before per-group parsing; presumably populated by
    # _parse_group_params — verify against that method.
    self.global_quantizer_constraints = {}
    self.ignored_scopes_per_group = {}
    self.target_scopes_per_group = {}
    self._op_names = []

    for group in QuantizerGroup:
        self._parse_group_params(self._algo_config, group)

    if self.should_init:
        self._parse_init_params()

    self._range_initializer = self._bn_adaptation = self._quantizer_setup = None

    # The 'TRIAL' target device leaves hw_config as None.
    self.hw_config = None
    if self._target_device == "TRIAL":
        return
    device_config_type = HWConfigType.from_str(
        HW_CONFIG_TYPE_TARGET_DEVICE_MAP[self._target_device])
    self.hw_config = TFHWConfig.from_json(
        TFHWConfig.get_path_to_hw_config(device_config_type))
def get_multipliers_from_config(config: NNCFConfig) -> Dict[str, float]:
    """Map every configured compression algorithm name to its LR multiplier.

    An algorithm-level 'compression_lr_multiplier' takes precedence. Otherwise
    the top-level 'compression_lr_multiplier' of ``config`` is used, and that
    defaults to 1. If several entries share an algorithm name, the last one wins.

    :param config: NNCF configuration to inspect.
    :return: dict of algorithm name -> learning-rate multiplier.
    """
    algorithms = get_config_algorithms(config)
    default_multiplier = config.get('compression_lr_multiplier', 1)
    return {
        entry['algorithm']: entry.get('compression_lr_multiplier', default_multiplier)
        for entry in algorithms
    }
def is_accuracy_aware_training(nncf_config: NNCFConfig) -> bool:
    """Return True when ``nncf_config`` has a non-None 'accuracy_aware_training' entry."""
    return nncf_config.get("accuracy_aware_training") is not None