Example #1
def create_compressed_model(model: tf.keras.Model,
                            config: NNCFConfig,
                            compression_state: Optional[Dict[str, Any]] = None) \
        -> Tuple[CompressionAlgorithmController, tf.keras.Model]:
    """
    The main function used to produce a model ready for compression fine-tuning
    from an original TensorFlow Keras model and a configuration object.

    :param model: The original model. Should have its parameters already loaded
        from a checkpoint or another source.
    :param config: A configuration object used to determine the exact compression
        modifications to be applied to the model.
    :param compression_state: The compression state that unambiguously restores
        the compressed model; includes the builder and controller states. If
        specified, trainable parameter initialization is skipped during building.
    :return: A tuple (compression_ctrl, compressed_model) where
        - compression_ctrl: The controller of the compression algorithm.
        - compressed_model: The model with additional modifications
            necessary to enable algorithm-specific compression during fine-tuning.
    """
    model = get_built_model(model, config)
    original_model_accuracy = None

    # Evaluate the original model so accuracy-aware training can use the
    # result as its baseline metric.
    if is_accuracy_aware_training(config):
        if config.has_extra_struct(ModelEvaluationArgs):
            evaluation_args = config.get_extra_struct(ModelEvaluationArgs)
            original_model_accuracy = evaluation_args.eval_fn(model)

    # Trainable parameter initialization is skipped when a compression state
    # is given, since the builder state is restored from it instead.
    builder = create_compression_algorithm_builder(
        config, should_init=not compression_state)
    if compression_state:
        builder.load_state(compression_state[BaseController.BUILDER_STATE])
    compressed_model = builder.apply_to(model)
    compression_ctrl = builder.build_controller(compressed_model)
    compressed_model.original_model_accuracy = original_model_accuracy
    if isinstance(compressed_model, tf.keras.Model):
        # Bind accuracy_aware_fit as an instance method of the compressed model
        compressed_model.accuracy_aware_fit = types.MethodType(
            accuracy_aware_fit, compressed_model)
    return compression_ctrl, compressed_model
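
A minimal usage sketch (assumptions: the function above is importable as nncf.tensorflow.create_compressed_model, and NNCFConfig.from_dict accepts the dict shown; the model and the algorithm choice are placeholders, not part of the original snippet):

import tensorflow as tf
from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model

# Placeholder Keras model; substitute your own pretrained model.
model = tf.keras.applications.MobileNetV2(weights=None)

# Placeholder configuration: an NHWC sample shape plus one compression algorithm.
nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 224, 224, 3]},
    "compression": {"algorithm": "magnitude_sparsity"}
})

compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)
# Fine-tune compressed_model as a regular Keras model, e.g. compressed_model.fit(...).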
Example #2
    def __init__(self, target_model: NNCFNetwork, prunable_types: List[str],
                 pruned_module_groups: Clusterization[PrunedModuleInfo],
                 pruned_norms_operators: List[Tuple[NNCFNode,
                                                    FilterPruningMask,
                                                    torch.nn.Module]],
                 config: NNCFConfig):
        #pylint:disable=too-many-statements
        super().__init__(target_model, prunable_types, pruned_module_groups,
                         config)
        params = self.pruning_config.get('params', {})
        self._pruned_norms_operators = pruned_norms_operators
        self.frozen = False
        self._pruning_level = 0
        self.pruning_init = self.pruning_config.get('pruning_init', 0)
        self.pruning_quota = 0.9  # upper bound on the share of filters prunable per group
        self.normalize_weights = True

        self._init_module_channels_and_shapes()
        self.pruning_quotas = {}
        self.nodes_flops = {}  # type: Dict[NNCFNodeName, int]
        self.nodes_params_num = {}  # type: Dict[NNCFNodeName, int]
        self.next_nodes = {}  # type: Dict[int, List[NNCFNodeName]]
        self._init_pruned_modules_params()
        self.flops_count_init()
        self.full_flops = sum(self.nodes_flops.values())
        self.current_flops = self.full_flops
        self.full_params_num = sum(self.nodes_params_num.values())
        self.current_params_num = self.full_params_num
        self.full_filters_num = count_filters_num(
            self._model.get_original_graph(), GENERAL_CONV_LAYER_METATYPES)
        self.current_filters_num = self.full_filters_num
        self._pruned_layers_num = len(
            self.pruned_module_groups_info.get_all_nodes())
        self._prunable_layers_num = len(
            self._model.get_graph().get_nodes_by_types(self._prunable_types))
        self._max_prunable_flops, self._max_prunable_params =\
            self._calculate_flops_and_weights_in_uniformly_pruned_model(1.)

        self.weights_normalizer = tensor_l2_normalizer  # used for all weights in the common case
        self.filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(
            params.get('filter_importance', 'L2'))
        self.ranking_type = params.get('interlayer_ranking_type',
                                       'unweighted_ranking')
        self.all_weights = params.get("all_weights", False)
        scheduler_cls = PRUNING_SCHEDULERS.get(
            params.get('schedule', 'exponential'))
        self._scheduler = scheduler_cls(self, params)

        if self.ranking_type == 'learned_ranking':
            # With the learned_ranking type, weights shouldn't be normalized
            self.normalize_weights = False
            if params.get('load_ranking_coeffs_path'):
                coeffs_path = params.get('load_ranking_coeffs_path')
                nncf_logger.info(
                    'Loading ranking coefficients from file {}'.format(
                        coeffs_path))
                try:
                    with open(coeffs_path, 'r',
                              encoding='utf8') as coeffs_file:
                        loaded_coeffs = json.load(coeffs_file)
                except (ValueError, FileNotFoundError) as err:
                    raise Exception(
                        'Can\'t load the JSON file with ranking coefficients. '
                        'Please check the file format and path.') from err
                ranking_coeffs = {
                    # JSON stores the per-node coefficients as lists; convert to tuples
                    key: tuple(loaded_coeffs[key])
                    for key in loaded_coeffs
                }
                nncf_logger.info(
                    'Loaded ranking coefficients = {}'.format(ranking_coeffs))
                self.ranking_coeffs = ranking_coeffs
            else:
                # Ranking can't be trained without the registered init struct LeGRInitArgs
                if not config.has_extra_struct(LeGRInitArgs):
                    raise Exception(
                        'Please register LeGRInitArgs via the register_default_init_args function.'
                    )
                # Wrapping model for parallelization
                distributed_wrapping_init_args = config.get_extra_struct(
                    DistributedCallbacksArgs)
                target_model = distributed_wrapping_init_args.wrap_model(
                    target_model)
                legr_init_args = config.get_extra_struct(LeGRInitArgs)
                legr_params = params.get("legr_params", {})
                if 'max_pruning' not in legr_params:
                    legr_params['max_pruning'] = self._scheduler.target_level
                self.legr = LeGR(self, target_model, legr_init_args,
                                 **legr_params)
                self.ranking_coeffs = self.legr.train_global_ranking()
                nncf_logger.info('Trained ranking coefficients = {}'.format(
                    self.ranking_coeffs))
                # Unwrapping parallelized model
                target_model = distributed_wrapping_init_args.unwrap_model(
                    target_model)
        else:
            self.ranking_coeffs = {
                node.node_name: (1, 0)
                for node in self.pruned_module_groups_info.get_all_nodes()
            }

        # Saving ranking coefficients to the specified file
        if params.get('save_ranking_coeffs_path'):
            nncf_logger.info(
                'Saving ranking coefficients to the file {}'.format(
                    params.get('save_ranking_coeffs_path')))
            with open(params.get('save_ranking_coeffs_path'),
                      'w',
                      encoding='utf8') as f:
                json.dump(self.ranking_coeffs, f)

        self.set_pruning_level(self.pruning_init)
        self._bn_adaptation = None
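
A minimal sketch of the configuration this constructor consumes: the params keys are taken from the params.get(...) calls above, while the "filter_pruning" algorithm name, the input shape, and the concrete values are assumptions for illustration:

from nncf import NNCFConfig

nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 3, 224, 224]},  # NCHW sample shape for PyTorch
    "compression": {
        "algorithm": "filter_pruning",
        "pruning_init": 0.1,  # read as self.pruning_init
        "params": {
            "schedule": "exponential",                        # PRUNING_SCHEDULERS key
            "filter_importance": "L2",                        # FILTER_IMPORTANCE_FUNCTIONS key
            "interlayer_ranking_type": "unweighted_ranking",  # or 'learned_ranking' to enable LeGR
            "all_weights": False
        }
    }
})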