def create_compressed_model(model: tf.keras.Model,
                            config: NNCFConfig,
                            compression_state: Optional[Dict[str, Any]] = None) \
        -> Tuple[CompressionAlgorithmController, tf.keras.Model]:
    """
    The main function used to produce a model ready for compression fine-tuning
    from an original TensorFlow Keras model and a configuration object.

    :param model: The original model. Should have its parameters already loaded
        from a checkpoint or another source.
    :param config: A configuration object used to determine the exact compression
        modifications to be applied to the model.
    :param compression_state: Compression state to unambiguously restore the
        compressed model. Includes builder and controller states. If it is
        specified, trainable parameter initialization will be skipped during
        building.
    :return: A tuple (compression_ctrl, compressed_model) where
        - compression_ctrl: The controller of the compression algorithm.
        - compressed_model: The model with additional modifications necessary
          to enable algorithm-specific compression during fine-tuning.
    """
    model = get_built_model(model, config)
    original_model_accuracy = None

    if is_accuracy_aware_training(config):
        if config.has_extra_struct(ModelEvaluationArgs):
            evaluation_args = config.get_extra_struct(ModelEvaluationArgs)
            original_model_accuracy = evaluation_args.eval_fn(model)

    builder = create_compression_algorithm_builder(config, should_init=not compression_state)

    if compression_state:
        builder.load_state(compression_state[BaseController.BUILDER_STATE])
    compressed_model = builder.apply_to(model)
    compression_ctrl = builder.build_controller(compressed_model)
    compressed_model.original_model_accuracy = original_model_accuracy
    if isinstance(compressed_model, tf.keras.Model):
        compressed_model.accuracy_aware_fit = types.MethodType(accuracy_aware_fit, compressed_model)
    return compression_ctrl, compressed_model
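# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the function above): how this entry
# point is typically called. It assumes the public `nncf.tensorflow` import
# path and the standard `input_info`/`compression` config schema; MobileNetV2
# and "magnitude_sparsity" are placeholders only. Check the installed NNCF
# version for exact names and schema.
import tensorflow as tf
from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model

# Any built Keras model whose weights are already loaded; MobileNetV2 is a stand-in.
keras_model = tf.keras.applications.MobileNetV2()

# Minimal config: the input shape plus a single compression algorithm section.
nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 224, 224, 3]},
    "compression": {"algorithm": "magnitude_sparsity"},
})

compression_ctrl, compressed_model = create_compressed_model(keras_model, nncf_config)

# Fine-tune `compressed_model` as usual; the controller drives the compression
# schedule, e.g. compression_ctrl.scheduler.epoch_step() once per training epoch.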
def __init__(self,
             target_model: NNCFNetwork,
             prunable_types: List[str],
             pruned_module_groups: Clusterization[PrunedModuleInfo],
             pruned_norms_operators: List[Tuple[NNCFNode, FilterPruningMask, torch.nn.Module]],
             config: NNCFConfig):
    # pylint:disable=too-many-statements
    super().__init__(target_model, prunable_types, pruned_module_groups, config)
    params = self.pruning_config.get('params', {})
    self._pruned_norms_operators = pruned_norms_operators
    self.frozen = False
    self._pruning_level = 0
    self.pruning_init = self.pruning_config.get('pruning_init', 0)
    self.pruning_quota = 0.9
    self.normalize_weights = True

    # Per-node FLOPs/parameter statistics used by FLOPs-aware pruning level computation.
    self._init_module_channels_and_shapes()
    self.pruning_quotas = {}
    self.nodes_flops = {}  # type: Dict[NNCFNodeName, int]
    self.nodes_params_num = {}  # type: Dict[NNCFNodeName, int]
    self.next_nodes = {}  # type: Dict[int, List[NNCFNodeName]]
    self._init_pruned_modules_params()
    self.flops_count_init()
    self.full_flops = sum(self.nodes_flops.values())
    self.current_flops = self.full_flops
    self.full_params_num = sum(self.nodes_params_num.values())
    self.current_params_num = self.full_params_num
    self.full_filters_num = count_filters_num(self._model.get_original_graph(),
                                              GENERAL_CONV_LAYER_METATYPES)
    self.current_filters_num = self.full_filters_num
    self._pruned_layers_num = len(self.pruned_module_groups_info.get_all_nodes())
    self._prunable_layers_num = len(self._model.get_graph().get_nodes_by_types(self._prunable_types))
    self._max_prunable_flops, self._max_prunable_params = \
        self._calculate_flops_and_weights_in_uniformly_pruned_model(1.)

    self.weights_normalizer = tensor_l2_normalizer  # used for all weights in the common case
    self.filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(params.get('filter_importance', 'L2'))
    self.ranking_type = params.get('interlayer_ranking_type', 'unweighted_ranking')
    self.all_weights = params.get("all_weights", False)
    scheduler_cls = PRUNING_SCHEDULERS.get(params.get('schedule', 'exponential'))
    self._scheduler = scheduler_cls(self, params)

    if self.ranking_type == 'learned_ranking':
        # With the learned_ranking ranking type, weights should not be normalized.
        self.normalize_weights = False
        if params.get('load_ranking_coeffs_path'):
            coeffs_path = params.get('load_ranking_coeffs_path')
            nncf_logger.info('Loading ranking coefficients from file {}'.format(coeffs_path))
            try:
                with open(coeffs_path, 'r', encoding='utf8') as coeffs_file:
                    loaded_coeffs = json.load(coeffs_file)
            except (ValueError, FileNotFoundError) as err:
                raise Exception('Can\'t load the JSON file with ranking coefficients. '
                                'Please check the format of the JSON file and the path to it.') from err
            ranking_coeffs = {key: tuple(loaded_coeffs[key]) for key in loaded_coeffs}
            nncf_logger.info('Loaded ranking coefficients = {}'.format(ranking_coeffs))
            self.ranking_coeffs = ranking_coeffs
        else:
            # Ranking can't be trained without the registered init struct LeGRInitArgs.
            if not config.has_extra_struct(LeGRInitArgs):
                raise Exception('Please register LeGRInitArgs via the register_default_init_args function.')
            # Wrap the model for parallelization.
            distributed_wrapping_init_args = config.get_extra_struct(DistributedCallbacksArgs)
            target_model = distributed_wrapping_init_args.wrap_model(target_model)
            legr_init_args = config.get_extra_struct(LeGRInitArgs)
            legr_params = params.get("legr_params", {})
            if 'max_pruning' not in legr_params:
                legr_params['max_pruning'] = self._scheduler.target_level
            self.legr = LeGR(self, target_model, legr_init_args, **legr_params)
            self.ranking_coeffs = self.legr.train_global_ranking()
            nncf_logger.info('Trained ranking coefficients = {}'.format(self.ranking_coeffs))
            # Unwrap the parallelized model.
            target_model = distributed_wrapping_init_args.unwrap_model(target_model)
    else:
        self.ranking_coeffs = {node.node_name: (1, 0)
                               for node in self.pruned_module_groups_info.get_all_nodes()}

    # Save ranking coefficients to the specified file.
    if params.get('save_ranking_coeffs_path'):
        nncf_logger.info('Saving ranking coefficients to the file {}'.format(
            params.get('save_ranking_coeffs_path')))
        with open(params.get('save_ranking_coeffs_path'), 'w', encoding='utf8') as f:
            json.dump(self.ranking_coeffs, f)

    self.set_pruning_level(self.pruning_init)
    self._bn_adaptation = None
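# ---------------------------------------------------------------------------
# Configuration sketch (illustrative): the "filter_pruning" section this
# constructor consumes. The nested keys mirror what the code above reads
# (pruning_init, params.schedule, params.filter_importance,
# params.interlayer_ranking_type, params.all_weights, params.legr_params, and
# the ranking-coefficient paths); the concrete values and the "pruning_target"
# key are assumptions to be checked against the NNCF config schema. With
# "learned_ranking", LeGRInitArgs and DistributedCallbacksArgs must be
# registered on the config before the controller is built.
from nncf import NNCFConfig

pruning_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 3, 224, 224]},
    "compression": {
        "algorithm": "filter_pruning",
        "pruning_init": 0.1,                 # initial level applied via set_pruning_level() above
        "params": {
            "schedule": "exponential",       # looked up in PRUNING_SCHEDULERS
            "pruning_target": 0.5,           # target level consumed by the scheduler (assumed key)
            "filter_importance": "L2",       # looked up in FILTER_IMPORTANCE_FUNCTIONS
            "interlayer_ranking_type": "unweighted_ranking",  # or "learned_ranking" to enable LeGR
            "all_weights": False,
            # Only relevant when "interlayer_ranking_type" is "learned_ranking":
            # "legr_params": {"max_pruning": 0.5},
            # "load_ranking_coeffs_path": "ranking_coeffs.json",
            # "save_ranking_coeffs_path": "ranking_coeffs.json",
        },
    },
})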