Example 1
    def _set_binary_masks_for_pruned_layers_globally(self,
                                                     pruning_level: float):
        """
        Sets the binary mask values for layer groups according to the global pruning level.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_level proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.
        """
        nncf_logger.debug(
            'Setting new binary masks for all pruned modules together.')
        filter_importances = {}
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        # a. Calculate importances for all groups of filters
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances[group.id] = cumulative_filters_importance

        # b. Calculate one threshold for all filters
        importances = tf.concat(list(filter_importances.values()), 0)
        threshold = sorted(importances)[int(pruning_level *
                                            importances.shape[0])]

        # c. Initialize masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            filter_mask = calculate_binary_mask(filter_importances[group.id],
                                                threshold)
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops with new masks
        self._update_benchmark_statistics()
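For intuition, the global-threshold step can be reproduced standalone. Below is a minimal sketch with hypothetical importance scores for two groups; it assumes calculate_binary_mask keeps the filters whose importance is at least the threshold.

import tensorflow as tf

# Hypothetical per-group filter importance scores.
group_importances = {
    0: tf.constant([0.9, 0.1, 0.5]),
    1: tf.constant([0.3, 0.7]),
}
pruning_level = 0.4  # prune the 40% least important filters globally

# Merge all scores and take the value separating the bottom 40%.
importances = tf.concat(list(group_importances.values()), 0)
threshold = sorted(importances.numpy())[int(pruning_level * importances.shape[0])]

# One shared threshold yields every group's binary mask.
masks = {group_id: tf.cast(imp >= threshold, tf.float32)
         for group_id, imp in group_importances.items()}
# masks[0] == [1., 0., 1.], masks[1] == [0., 1.]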
Example 2
    def _set_binary_masks_for_pruned_layers_groupwise(self,
                                                      pruning_level: float):
        """
        Sets the binary mask values for layer groups according to the group-wise
        pruning level. A separate importance threshold is calculated for each
        group, so every group is pruned by the same pruning_level proportion.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            # a. Calculate the cumulative importance for all filters in the group
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filters_num = len(cumulative_filters_importance)

            # b. Calculate threshold
            num_of_sparse_elems = get_rounded_pruned_element_number(
                cumulative_filters_importance.shape[0], pruning_level)
            threshold = sorted(cumulative_filters_importance)[min(
                num_of_sparse_elems, filters_num - 1)]

            # c. Initialize masks
            filter_mask = calculate_binary_mask(cumulative_filters_importance,
                                                threshold)
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()
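Compared with the global variant, the threshold here is local: each group prunes its own pruning_level share of filters. A toy sketch for a single group (hypothetical scores; NNCF's get_rounded_pruned_element_number is replaced with plain truncation for brevity):

import tensorflow as tf

importance = tf.constant([0.9, 0.1, 0.5, 0.3, 0.7, 0.2])  # one group
pruning_level = 0.5
filters_num = int(importance.shape[0])

# Simplified stand-in for get_rounded_pruned_element_number.
num_of_sparse_elems = int(filters_num * pruning_level)
threshold = sorted(importance.numpy())[min(num_of_sparse_elems, filters_num - 1)]

mask = tf.cast(importance >= threshold, tf.float32)
# mask == [1., 0., 1., 0., 1., 0.]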
Example 3
    def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Computes necessary model transformations (pruning mask insertions) to enable pruning.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        converter = TFModelConverterFactory.create(model)
        self._graph = converter.convert()
        groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

        transformations = TFTransformationLayout()
        shared_layers = set()

        self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)
                group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                    is_prunable_depthwise_conv(node)))

                # Add output_mask to the elements to run mask_propagation
                # and detect spec_nodes that will be pruned.
                # It should be done for all elements of a shared layer.
                node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
                if layer_name in shared_layers:
                    continue
                if node.is_shared():
                    shared_layers.add(layer_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

                _, layer_info = converter.get_layer_info_for_node(node.node_name)
                for weight_def in node.metatype.weight_definitions:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, weight_def.weight_attr_name)
                    )
                if node.metatype.bias_attr_name is not None and \
                        getattr(layer, node.metatype.bias_attr_name) is not None:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, node.metatype.bias_attr_name)
                    )

            cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
            self._pruned_layer_groups_info.add_cluster(cluster)

        # Propagating masks across the graph to detect spec_nodes that will be pruned
        mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                                   TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # Add masks for all spec modules, because prunable batchnorm layers
        # can only be determined during mask propagation
        types_spec_layers = [TFBatchNormalizationLayerMetatype] \
            if self._prune_batch_norms else []

        spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
        for spec_node in spec_nodes:
            layer_name = get_layer_identifier(spec_node)
            layer = model.get_layer(layer_name)
            if spec_node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue
            if layer_name in shared_layers:
                continue
            if spec_node.is_shared():
                shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
            for weight_def in spec_node.metatype.weight_definitions:
                if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                        and not layer.scale and weight_def.weight_attr_name == 'gamma':
                    nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                      'Do not add mask to it.')
                    continue

                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, spec_node.metatype.bias_attr_name)
            )
        return transformations
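A minimal usage sketch of the method above. The builder object and the exact transformer wiring are assumptions here (modeled on NNCF's TF backend, where TFModelTransformer applies a TFTransformationLayout to a Keras model), not taken from the source:

import tensorflow as tf
from nncf.tensorflow.graph.model_transformer import TFModelTransformer

model = tf.keras.applications.MobileNetV2(weights=None)

# `builder` stands for the pruning algorithm builder owning
# get_transformation_layout; its construction is omitted here.
layout = builder.get_transformation_layout(model)

# Apply the registered binary-mask insertion commands to the model.
pruned_model = TFModelTransformer(model).transform(layout)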
Example 4
    @classmethod
    def repeat(cls, tensor: NNCFTensor, repeats: int) -> NNCFTensor:
        # Unwrap, repeat each element `repeats` times, and re-wrap the result.
        ret_tensor = tf.repeat(tensor.tensor, repeats=repeats)
        return TFNNCFTensor(ret_tensor)
Example 5
    @classmethod
    def ones(cls, shape: Union[int, List[int]],
             device: tf.device) -> NNCFTensor:
        # Allocate the all-ones tensor on the requested device.
        with tf.device(device):
            return TFNNCFTensor(tf.ones(shape))
Example 6
    @classmethod
    def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor:
        # pylint: disable=E1120,E1123
        ret_tensor = tf.concat([t.tensor for t in tensors], axis=axis)
        return TFNNCFTensor(ret_tensor)
Example 7
    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_level: float):
        """
        Prunes least important filters one-by-one until target FLOPs pruning level is achieved.
        Filters are sorted by filter importance score.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        target_flops = self.full_flops * (1 - target_flops_pruning_level)
        wrapped_layers = collect_wrapped_layers(self._model)
        masks = {}

        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            nncf_node.data['output_mask'] = TFNNCFTensor(
                tf.ones(get_filters_num(layer)))

        # 1. Calculate importances for all groups of filters. Initialize masks.
        filter_importances = []
        group_indexes = []
        filter_indexes = []
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances.extend(cumulative_filters_importance)
            filters_num = len(cumulative_filters_importance)
            group_indexes.extend([group.id] * filters_num)
            filter_indexes.extend(range(filters_num))
            masks[group.id] = tf.ones(filters_num)

        # 2. Disable filters one-by-one, least important first,
        # until the target FLOPs level is reached
        tmp_in_channels = self._layers_in_channels.copy()
        tmp_out_channels = self._layers_out_channels.copy()
        sorted_importances = sorted(zip(filter_importances, group_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        for _, group_id, filter_index in sorted_importances:
            if self._pruning_quotas[group_id] == 0:
                continue
            masks[group_id] = tf.tensor_scatter_nd_update(
                masks[group_id], [[filter_index]], [0])
            self._pruning_quotas[group_id] -= 1

            # Update input/output shapes of pruned elements
            group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
            for node in group.elements:
                tmp_out_channels[node.node_name] -= 1
                if node.is_depthwise:
                    tmp_in_channels[node.node_name] -= 1

            for node_name in self._next_nodes[group_id]:
                tmp_in_channels[node_name] -= 1

            flops, params_num = count_flops_and_weights(
                self._original_graph,
                self._layers_in_shapes,
                self._layers_out_shapes,
                input_channels=tmp_in_channels,
                output_channels=tmp_out_channels,
                conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                linear_op_metatypes=LINEAR_LAYER_METATYPES)
            if flops <= target_flops:
                # 3. Add masks to the graph and propagate them
                for group in self._pruned_layer_groups_info.get_all_clusters():
                    for node in group.elements:
                        nncf_node = self._original_graph.get_node_by_id(
                            node.nncf_node_id)
                        nncf_node.data['output_mask'] = TFNNCFTensor(
                            masks[group.id])

                mask_propagator = MaskPropagationAlgorithm(
                    self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
                    TFNNCFPruningTensorProcessor)
                mask_propagator.mask_propagation()

                # 4. Set binary masks to the model
                self.current_flops = flops
                self.current_params_num = params_num
                nncf_sorted_nodes = self._original_graph.topological_sort()
                for layer in wrapped_layers:
                    nncf_node = [
                        n for n in nncf_sorted_nodes
                        if layer.name == n.layer_name
                    ][0]
                    if nncf_node.data['output_mask'] is not None:
                        self._set_operation_masks(
                            [layer], nncf_node.data['output_mask'].tensor)
                return
        raise RuntimeError(
            'Unable to prune the model to the required FLOPs pruning level: '
            f'{target_flops_pruning_level}')
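The greedy loop can be condensed to its essence. A toy sketch with one group, hypothetical importance scores, and a made-up uniform FLOPs cost per filter in place of count_flops_and_weights:

import tensorflow as tf

full_flops = 1000.0
target_flops = full_flops * (1 - 0.3)     # 30% FLOPs pruning level
importances = [0.9, 0.1, 0.5, 0.3, 0.7]   # hypothetical scores, one group
flops_per_filter = 120.0                  # made-up uniform cost per filter

mask = tf.ones(len(importances))
flops = full_flops
for _, filter_index in sorted((imp, i) for i, imp in enumerate(importances)):
    # Switch off the least important remaining filter.
    mask = tf.tensor_scatter_nd_update(mask, [[filter_index]], [0.0])
    flops -= flops_per_filter
    if flops <= target_flops:
        break
# mask == [1., 0., 0., 0., 1.], flops == 640.0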
Example 8
    @staticmethod
    def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
        tensor_list = tf.unstack(x.tensor, axis=axis)
        return [TFNNCFTensor(t) for t in tensor_list]
Example 9
    def _register_input(self, x: tf.Tensor):
        # Wrap the raw tensor so the backend-agnostic collector can consume it.
        self._register_input_common(TFNNCFTensor(x))
Example 10
    @staticmethod
    def stack(x: Union[List[NNCFTensor], Deque[NNCFTensor]],
              axis: int = 0) -> NNCFTensor:
        x = [t.tensor for t in x]
        return TFNNCFTensor(tf.stack(x, axis=axis))
Example 11
    @staticmethod
    def mean(x: NNCFTensor, axis: Union[int, tuple, list]) -> NNCFTensor:
        return TFNNCFTensor(tf.math.reduce_mean(x.tensor, axis=axis))
Example 12
    @staticmethod
    def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
        return TFNNCFTensor(tf.math.maximum(x1.tensor, x2.tensor))
Example 13
    @staticmethod
    def abs(x: NNCFTensor) -> NNCFTensor:
        return TFNNCFTensor(tf.math.abs(x.tensor))
Example 14
    @staticmethod
    def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list]) -> NNCFTensor:
        return TFNNCFTensor(tf.reduce_max(x.tensor, axis=axis))
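These wrappers compose as expected. Assuming they are methods of an NNCF tensor-processor class (invoked below as plain functions for illustration), a quick round trip looks like:

import tensorflow as tf

a = TFNNCFTensor(tf.constant([1.0, -2.0]))
b = TFNNCFTensor(tf.constant([0.5, 3.0]))

stacked = stack([a, b])                 # TFNNCFTensor wrapping shape (2, 2)
per_column_max = reduce_max(stacked, axis=0)
# per_column_max.tensor == [1., 3.]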