def _set_binary_masks_for_pruned_layers_globally(self, pruning_level: float):
    """
    Sets the binary mask values for layer groups according to the global pruning level.
    Filter importance scores in each group are merged into a single global list and a
    threshold value separating the pruning_level proportion of the least important filters
    in the model is calculated. Filters are pruned globally according to the threshold value.

    :param pruning_level: Proportion of all filters (over every group) that defines the
        position of the threshold in the sorted list of importances.
    """
    nncf_logger.debug('Setting new binary masks for all pruned modules together.')
    filter_importances = {}
    wrapped_layers = collect_wrapped_layers(self._model)

    # 0. Remove masks at the elements of the NNCFGraph
    for node in self._original_graph.topological_sort():
        node.data.pop('output_mask', None)

    # 1. Calculate masks
    # a. Calculate importances for all groups of filters
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = self._calculate_filters_importance_in_group(group)
        filter_importances[group.id] = cumulative_filters_importance

    # b. Calculate one threshold for all weights
    importances = tf.concat(list(filter_importances.values()), 0)
    filters_num = importances.shape[0]
    # Clamp the index to the last element so that pruning_level == 1.0 does not
    # index past the end of the sorted importances (the groupwise variant of this
    # method clamps its index in the same way).
    threshold_index = min(int(pruning_level * filters_num), filters_num - 1)
    threshold = sorted(importances)[threshold_index]

    # c. Initialize masks
    for group in self._pruned_layer_groups_info.get_all_clusters():
        filter_mask = calculate_binary_mask(filter_importances[group.id], threshold)
        for node in group.elements:
            nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id)
            nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

    # 2. Propagate masks across the graph
    mask_propagator = MaskPropagationAlgorithm(self._original_graph,
                                               TF_PRUNING_OPERATOR_METATYPES,
                                               TFNNCFPruningTensorProcessor)
    mask_propagator.mask_propagation()

    # 3. Apply masks to the model
    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        # The first node (in topological order) that refers to this layer is used.
        nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0]
        if nncf_node.data['output_mask'] is not None:
            self._set_operation_masks([layer], nncf_node.data['output_mask'].tensor)

    # Calculate actual flops with new masks
    self._update_benchmark_statistics()
def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_level: float):
    """
    Sets the binary mask values for layer groups, computing a separate importance
    threshold inside every group of pruned layers.

    :param pruning_level: Pruning level passed, together with the number of filters
        in the group, to `get_rounded_pruned_element_number` to pick the threshold position.
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    wrapped_layers = collect_wrapped_layers(self._model)

    # 0. Drop the masks left on the NNCFGraph nodes
    for graph_node in self._original_graph.topological_sort():
        graph_node.data.pop('output_mask', None)

    # 1. Compute one mask per group
    for cluster in self._pruned_layer_groups_info.get_all_clusters():
        # a. Cumulative importance of every filter in the group
        importance = self._calculate_filters_importance_in_group(cluster)
        filters_num = len(importance)

        # b. Threshold position, clamped to the last valid index
        num_of_sparse_elems = get_rounded_pruned_element_number(importance.shape[0],
                                                                pruning_level)
        threshold_index = min(num_of_sparse_elems, filters_num - 1)
        threshold = sorted(importance)[threshold_index]

        # c. The same mask is attached to every element of the group
        group_mask = calculate_binary_mask(importance, threshold)
        for minfo in cluster.elements:
            target_node = self._original_graph.get_node_by_id(minfo.nncf_node_id)
            target_node.data['output_mask'] = TFNNCFTensor(group_mask)

    # 2. Propagate the masks across the graph
    propagator = MaskPropagationAlgorithm(self._original_graph,
                                          TF_PRUNING_OPERATOR_METATYPES,
                                          TFNNCFPruningTensorProcessor)
    propagator.mask_propagation()

    # 3. Apply the masks to the model
    sorted_nodes = self._original_graph.topological_sort()
    for wrapped in wrapped_layers:
        # The first node (in topological order) that refers to this layer is used.
        matching_nodes = [n for n in sorted_nodes if wrapped.name == n.layer_name]
        output_mask = matching_nodes[0].data['output_mask']
        if output_mask is None:
            continue
        self._set_operation_masks([wrapped], output_mask.tensor)

    # Calculate actual flops and weights number with new masks
    self._update_benchmark_statistics()
def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
    """
    Computes necessary model transformations (pruning mask insertions) to enable pruning.

    :param model: The original uncompressed model.
    :return: The instance of the `TransformationLayout` class containing
        a list of pruning mask insertions.
    """
    converter = TFModelConverterFactory.create(model)
    self._graph = converter.convert()
    groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

    transformations = TFTransformationLayout()
    # Names of shared layers that already received their mask insertions.
    shared_layers = set()

    self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

    for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
        group_minfos = []
        for node in group.elements:
            layer_name = get_layer_identifier(node)
            layer = model.get_layer(layer_name)
            group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                is_prunable_depthwise_conv(node)))

            # Add output_mask to elements to run mask_propagation
            # and detect spec_nodes that will be pruned.
            # It should be done for all elements of shared layer.
            node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
            if layer_name in shared_layers:
                continue
            if node.is_shared():
                shared_layers.add(layer_name)
            # Check that we need to prune weights in this op
            assert self._is_pruned_layer(layer)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(node.node_name)
            for weight_def in node.metatype.weight_definitions:
                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            if node.metatype.bias_attr_name is not None and \
                    getattr(layer, node.metatype.bias_attr_name) is not None:
                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, node.metatype.bias_attr_name)
                )

        cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
        self._pruned_layer_groups_info.add_cluster(cluster)

    # Propagating masks across the graph to detect spec_nodes that will be pruned
    mask_propagator = MaskPropagationAlgorithm(self._graph,
                                               TF_PRUNING_OPERATOR_METATYPES,
                                               TFNNCFPruningTensorProcessor)
    mask_propagator.mask_propagation()

    # Add masks for all spec modules, because prunable batchnorm layers can be determined
    # at the moment of mask propagation
    types_spec_layers = [TFBatchNormalizationLayerMetatype] \
        if self._prune_batch_norms else []

    spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
    for spec_node in spec_nodes:
        layer_name = get_layer_identifier(spec_node)
        layer = model.get_layer(layer_name)
        if spec_node.data['output_mask'] is None:
            # Skip elements that will not be pruned
            continue
        if layer_name in shared_layers:
            continue
        if spec_node.is_shared():
            shared_layers.add(layer_name)
        nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

        _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
        for weight_def in spec_node.metatype.weight_definitions:
            if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                    and not layer.scale and weight_def.weight_attr_name == 'gamma':
                nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                  'Do not add mask to it.')
                continue

            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, weight_def.weight_attr_name)
            )

        # Register the bias mask only when the layer really owns a bias weight,
        # using the same guard as for the prunable layers above.
        bias_attr_name = spec_node.metatype.bias_attr_name
        if bias_attr_name is not None and getattr(layer, bias_attr_name) is not None:
            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, bias_attr_name)
            )
        else:
            nncf_logger.debug('Bias parameter is absent in layer %s. '
                              'Do not add mask to it.', layer_name)
    return transformations
def repeat(cls, tensor: NNCFTensor, repeats: int) -> NNCFTensor:
    """
    Repeats the elements of the tensor with `tf.repeat` (no axis is passed).

    :param tensor: The tensor whose elements are repeated.
    :param repeats: The number of repetitions, forwarded to `tf.repeat`.
    :return: The result of `tf.repeat` wrapped into a `TFNNCFTensor`.
    """
    # `tf.repeat` works on the backend tensor, so unwrap the NNCFTensor first
    # (as the other helpers in this file do). A raw tensor is still accepted
    # as-is to stay compatible with callers that passed one.
    raw_tensor = tensor.tensor if isinstance(tensor, NNCFTensor) else tensor
    ret_tensor = tf.repeat(raw_tensor, repeats=repeats)
    return TFNNCFTensor(ret_tensor)
def ones(cls, shape: Union[int, List[int]], device: tf.device) -> NNCFTensor:
    """
    Creates a tensor of ones inside the `tf.device(device)` scope.

    :param shape: The shape forwarded to `tf.ones`.
    :param device: The device specification forwarded to `tf.device`.
    :return: The created tensor wrapped into a `TFNNCFTensor`.
    """
    with tf.device(device):
        ones_tensor = tf.ones(shape)
        wrapped_ones = TFNNCFTensor(ones_tensor)
    return wrapped_ones
def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor:
    """
    Concatenates the wrapped tensors along the given axis with `tf.concat`.

    :param tensors: The tensors to join; the `.tensor` attribute of each one is used.
    :param axis: The axis forwarded to `tf.concat`.
    :return: The concatenated tensor wrapped into a `TFNNCFTensor`.
    """
    raw_tensors = []
    for wrapped in tensors:
        raw_tensors.append(wrapped.tensor)
    joined = tf.concat(raw_tensors, axis=axis)  # pylint: disable=E1120,E1123
    return TFNNCFTensor(joined)
def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
        self, target_flops_pruning_level: float):
    """
    Prunes least important filters one-by-one until target FLOPs pruning level is achieved.
    Filters are sorted by filter importance score.

    :param target_flops_pruning_level: Proportion of `self.full_flops` to remove; pruning
        stops once the counted FLOPs are not greater than
        `self.full_flops * (1 - target_flops_pruning_level)`.
    :raises RuntimeError: If the target is not reached after all filters allowed by the
        pruning quotas were considered.
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    target_flops = self.full_flops * (1 - target_flops_pruning_level)
    wrapped_layers = collect_wrapped_layers(self._model)
    masks = {}

    # 0. Reset the masks of the nodes of wrapped layers to all-ones
    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0]
        nncf_node.data['output_mask'] = TFNNCFTensor(tf.ones(get_filters_num(layer)))

    # 1. Calculate importances for all groups of filters. Initialize masks.
    filter_importances = []
    group_indexes = []
    filter_indexes = []
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = self._calculate_filters_importance_in_group(group)
        filter_importances.extend(cumulative_filters_importance)
        filters_num = len(cumulative_filters_importance)
        group_indexes.extend([group.id] * filters_num)
        filter_indexes.extend(range(filters_num))
        masks[group.id] = tf.ones(filters_num)

    # 2. Zero out filters in the order of increasing importance, recounting FLOPs
    #    after every removed filter.
    tmp_in_channels = self._layers_in_channels.copy()
    tmp_out_channels = self._layers_out_channels.copy()
    # Work on a copy of the quotas: decrementing `self._pruning_quotas` in place would
    # leave them depleted for any later call (and after the RuntimeError below).
    tmp_pruning_quotas = self._pruning_quotas.copy()
    sorted_importances = sorted(zip(filter_importances, group_indexes, filter_indexes),
                                key=lambda x: x[0])
    for _, group_id, filter_index in sorted_importances:
        if tmp_pruning_quotas[group_id] == 0:
            continue
        masks[group_id] = tf.tensor_scatter_nd_update(masks[group_id], [[filter_index]], [0])
        tmp_pruning_quotas[group_id] -= 1

        # Update input/output shapes of pruned elements
        group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
        for node in group.elements:
            tmp_out_channels[node.node_name] -= 1
            if node.is_depthwise:
                tmp_in_channels[node.node_name] -= 1

        for node_name in self._next_nodes[group_id]:
            tmp_in_channels[node_name] -= 1

        flops, params_num = count_flops_and_weights(self._original_graph,
                                                    self._layers_in_shapes,
                                                    self._layers_out_shapes,
                                                    input_channels=tmp_in_channels,
                                                    output_channels=tmp_out_channels,
                                                    conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                                                    linear_op_metatypes=LINEAR_LAYER_METATYPES)
        if flops <= target_flops:
            # 3. Add masks to the graph and propagate them
            for group in self._pruned_layer_groups_info.get_all_clusters():
                for node in group.elements:
                    nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id)
                    nncf_node.data['output_mask'] = TFNNCFTensor(masks[group.id])

            mask_propagator = MaskPropagationAlgorithm(self._original_graph,
                                                       TF_PRUNING_OPERATOR_METATYPES,
                                                       TFNNCFPruningTensorProcessor)
            mask_propagator.mask_propagation()

            # 4. Set binary masks to the model
            self.current_flops = flops
            self.current_params_num = params_num
            nncf_sorted_nodes = self._original_graph.topological_sort()
            for layer in wrapped_layers:
                nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0]
                if nncf_node.data['output_mask'] is not None:
                    self._set_operation_masks([layer], nncf_node.data['output_mask'].tensor)
            return
    raise RuntimeError(f'Unable to prune model to required flops pruning level:'
                       f' {target_flops_pruning_level}')
def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
    """
    Splits the wrapped tensor along the given axis with `tf.unstack`.

    :param x: The tensor to split; its `.tensor` attribute is used.
    :param axis: The axis forwarded to `tf.unstack`.
    :return: The list of resulting tensors, each wrapped into a `TFNNCFTensor`.
    """
    wrapped_pieces = []
    for piece in tf.unstack(x.tensor, axis=axis):
        wrapped_pieces.append(TFNNCFTensor(piece))
    return wrapped_pieces
def _register_input(self, x: tf.Tensor):
    """
    Wraps the given tensor into a `TFNNCFTensor` and passes it to `_register_input_common`.

    :param x: The tensor to register.
    """
    wrapped_input = TFNNCFTensor(x)
    self._register_input_common(wrapped_input)
def stack(x: Union[List[NNCFTensor], Deque[NNCFTensor]], axis: int = 0) -> NNCFTensor:
    """
    Stacks the wrapped tensors along the given axis with `tf.stack`.

    :param x: The collection of tensors to stack. The `.tensor` attribute of every
        element is read, so the elements must be NNCFTensor wrappers rather than raw
        `tf.Tensor` objects.
    :param axis: The axis forwarded to `tf.stack`.
    :return: The stacked tensor wrapped into a `TFNNCFTensor`.
    """
    # Keep the argument intact and collect the unwrapped tensors in a separate list.
    raw_tensors = [t.tensor for t in x]
    return TFNNCFTensor(tf.stack(raw_tensors, axis=axis))
def mean(x: NNCFTensor, axis: Union[int, tuple, list]) -> NNCFTensor:
    """
    Reduces the wrapped tensor with `tf.math.reduce_mean` over the given axis.

    :param x: The tensor to reduce; its `.tensor` attribute is used.
    :param axis: The axis (or axes) forwarded to `tf.math.reduce_mean`.
    :return: The reduced tensor wrapped into a `TFNNCFTensor`.
    """
    reduced = tf.math.reduce_mean(x.tensor, axis=axis)
    return TFNNCFTensor(reduced)
def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
    """
    Applies `tf.math.maximum` to the two wrapped tensors.

    :param x1: The first operand. Its `.tensor` attribute is read, so it must be an
        NNCFTensor wrapper rather than a raw `tf.Tensor`.
    :param x2: The second operand, with the same requirement as `x1`.
    :return: The result of `tf.math.maximum` wrapped into a `TFNNCFTensor`.
    """
    return TFNNCFTensor(tf.math.maximum(x1.tensor, x2.tensor))
def abs(x: NNCFTensor) -> NNCFTensor:
    """
    Applies `tf.math.abs` to the wrapped tensor.

    :param x: The input tensor; its `.tensor` attribute is used.
    :return: The result of `tf.math.abs` wrapped into a `TFNNCFTensor`.
    """
    absolute = tf.math.abs(x.tensor)
    return TFNNCFTensor(absolute)
def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list]) -> NNCFTensor:
    """
    Reduces the wrapped tensor with `tf.reduce_max` over the given axis.

    :param x: The tensor to reduce; its `.tensor` attribute is used.
    :param axis: The axis (or axes) forwarded to `tf.reduce_max`.
    :return: The reduced tensor wrapped into a `TFNNCFTensor`.
    """
    reduced = tf.reduce_max(x.tensor, axis=axis)
    return TFNNCFTensor(reduced)