def _collect_all_weights(self):
    """Gather flattened importance tensors for every sparsified weight.

    Walks all NNCF-wrapped layers of the model and, for each weight
    attribute, appends one flattened importance tensor per operation whose
    name is registered in ``self._op_names``.

    :return: List of 1-D ``tf.Tensor`` importance values.
    """
    flattened = []
    for layer in collect_wrapped_layers(self._model):
        for attr_name, attr_ops in layer.weights_attr_ops.items():
            weight = layer.layer_weights[attr_name]
            for op_name in attr_ops:
                if op_name not in self._op_names:
                    continue
                flattened.append(
                    tf.reshape(self._weight_importance_fn(weight), [-1]))
    return flattened
def _set_masks_for_threshold(self, threshold_val):
    """Recompute every ``BinaryMask`` operation for a new magnitude threshold.

    For each wrapped layer's weight attribute, every ``BinaryMask`` op gets
    its mask variable reassigned from ``calc_magnitude_binary_mask``.

    :param threshold_val: Magnitude threshold below which weights are masked.
    """
    for layer in collect_wrapped_layers(self._model):
        for attr_name, attr_ops in layer.weights_attr_ops.items():
            weight_tensor = layer.layer_weights[attr_name]
            for op_name, op in attr_ops.items():
                if not isinstance(op, BinaryMask):
                    continue
                new_mask = calc_magnitude_binary_mask(weight_tensor,
                                                      self.weight_importance,
                                                      threshold_val)
                # NOTE(review): assigns directly to ops_weights[op_name];
                # other variants in this file index ['mask'] — confirm which
                # wrapper layout this controller version uses.
                layer.ops_weights[op_name].assign(new_mask)
def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float):
    """
    Sets the binary mask values for layer groups according to the global
    pruning rate. Filter importance scores in each group are merged into a
    single global list and a threshold value separating the pruning_rate
    proportion of the least important filters in the model is calculated.
    Filters are pruned globally according to the threshold value.

    :param pruning_rate: Proportion (0..1) of least-important filters to prune.
    """
    nncf_logger.debug('Setting new binary masks for all pruned modules together.')
    filter_importances = {}
    wrapped_layers = collect_wrapped_layers(self._model)

    # 0. Remove masks at the nodes of the NNCFGraph
    for node in self._original_graph.topological_sort():
        node.data.pop('output_mask', None)

    # 1. Calculate masks
    # a. Calculate importances for all groups of filters
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = \
            self._calculate_filters_importance_in_group(group, wrapped_layers)
        filter_importances[group.id] = cumulative_filters_importance

    # b. Calculate one threshold for all weights.
    # FIX: clamp the index so pruning_rate == 1.0 does not index one past the
    # end of the sorted list (mirrors the min(..., n - 1) clamping used in
    # the groupwise variant of this method).
    importances = tf.concat(list(filter_importances.values()), 0)
    threshold_index = min(int(pruning_rate * importances.shape[0]),
                          importances.shape[0] - 1)
    threshold = sorted(importances)[threshold_index]

    # c. Initialize masks
    for group in self._pruned_layer_groups_info.get_all_clusters():
        filter_mask = calculate_binary_mask(filter_importances[group.id], threshold)
        for node in group.nodes:
            nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id)
            nncf_node.data['output_mask'] = filter_mask

    # 2. Propagate masks across the graph
    mask_propagator = MaskPropagationAlgorithm(self._original_graph,
                                               TF_PRUNING_OPERATOR_METATYPES)
    mask_propagator.mask_propagation()

    # 3. Apply masks to the model
    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        # Match each wrapped layer to its graph node by original layer name.
        nncf_node = [
            n for n in nncf_sorted_nodes
            if layer.layer.name == get_original_name(n.node_name)
        ][0]
        if nncf_node.data['output_mask'] is not None:
            self._set_operation_masks([layer], nncf_node.data['output_mask'])
def _collect_all_weights(self):
    """Collect flattened importance tensors for all ``BinaryMask``-sparsified
    weights of the model.

    :return: List of 1-D ``tf.Tensor`` importance values, one entry per
        ``BinaryMask`` operation found on any wrapped layer.
    """
    collected = []
    for layer in collect_wrapped_layers(self._model):
        for attr_name, attr_ops in layer.weights_attr_ops.items():
            weight = layer.layer_weights[attr_name]
            collected.extend(
                tf.reshape(self.weight_importance(weight), [-1])
                for op in attr_ops.values() if isinstance(op, BinaryMask))
    return collected
def _set_masks_for_threshold(self, threshold_val):
    """Reassign the 'mask' variable of every tracked sparsity operation
    using a magnitude threshold.

    Only operations whose names appear in ``self._op_names`` are updated.

    :param threshold_val: Magnitude threshold below which weights are masked.
    """
    for layer in collect_wrapped_layers(self._model):
        for attr_name, attr_ops in layer.weights_attr_ops.items():
            weight_tensor = layer.layer_weights[attr_name]
            for op_name in attr_ops:
                if op_name not in self._op_names:
                    continue
                updated_mask = calc_magnitude_binary_mask(
                    weight_tensor, self._weight_importance_fn, threshold_val)
                layer.ops_weights[op_name]['mask'].assign(updated_mask)
def apply_fn_to_op_weights(model: tf.keras.Model, op_names: List[str], fn=lambda x: x):
    """Collect (operation, fn(weights)) pairs for the named operations.

    Scans every wrapped layer of *model*; for each operation whose name is
    in *op_names*, fetches that operation's weights and applies *fn* to them
    (identity by default).

    :param model: Keras model containing NNCF-wrapped layers.
    :param op_names: Names of the operations to select.
    :param fn: Transform applied to each operation's weights.
    :return: List of ``(op, fn(weights))`` tuples.
    """
    selected = []
    for wrapped in collect_wrapped_layers(model):
        for attr_ops in wrapped.weights_attr_ops.values():
            for op in attr_ops.values():
                if op.name not in op_names:
                    continue
                op_weights = wrapped.get_operation_weights(op.name)
                selected.append((op, fn(op_weights)))
    return selected
def raw_statistics(self):
    """Build a raw magnitude-sparsity statistics dictionary for the model.

    Returns a dict with overall sparsity rates, the current magnitude
    threshold, and a per-layer breakdown under
    ``'sparsity_statistic_by_layer'``.
    """
    raw_sparsity_statistics = {}
    sparsity_levels = []
    mask_names = []
    weights_shapes = []
    weights_numbers = []
    total_weights_number = tf.constant(0)
    total_sparsified_weights_number = tf.constant(0)
    # Element count of backup-weight variables (BinaryMaskWithWeightsBackup);
    # excluded from the model weight count below.
    total_bkup_weights_number = tf.constant(0)
    wrapped_layers = collect_wrapped_layers(self._model)
    for wrapped_layer in wrapped_layers:
        for ops in wrapped_layer.weights_attr_ops.values():
            for op_name, op in ops.items():
                if op_name in self._op_names:
                    if isinstance(op, BinaryMaskWithWeightsBackup):
                        total_bkup_weights_number += tf.size(op.bkup_var)
                    if isinstance(op, BinaryMask):
                        mask = wrapped_layer.ops_weights[op_name]['mask']
                        mask_names.append(mask.name)
                        weights_shapes.append(list(mask.shape))
                        weights_number = tf.size(mask)
                        weights_numbers.append(weights_number)
                        # Zeros in the binary mask correspond to sparsified weights.
                        sparsified_weights_number = weights_number - tf.reduce_sum(tf.cast(mask, tf.int32))
                        sparsity_levels.append(sparsified_weights_number / weights_number)
                        total_weights_number += weights_number
                        total_sparsified_weights_number += sparsified_weights_number
    sparsity_rate_for_sparsified_modules = (total_sparsified_weights_number / total_weights_number).numpy()
    # Subtracts the mask and backup variable sizes from the total parameter
    # count — presumably because those auxiliary variables are registered in
    # model.weights and should not count as real model weights (TODO confirm).
    model_weights_number = count_params(self._model.weights) - total_weights_number - total_bkup_weights_number
    sparsity_rate_for_model = (total_sparsified_weights_number / model_weights_number).numpy()
    raw_sparsity_statistics.update({
        'sparsity_rate_for_sparsified_modules': sparsity_rate_for_sparsified_modules,
        'sparsity_rate_for_model': sparsity_rate_for_model,
        'sparsity_threshold': self._threshold
    })
    # Evaluate all symbolic per-layer ratios in one batched call.
    sparsity_levels = tf.keras.backend.batch_get_value(sparsity_levels)
    weights_percentages = [weights_number / total_weights_number * 100
                           for weights_number in weights_numbers]
    weights_percentages = tf.keras.backend.batch_get_value(weights_percentages)
    mask_sparsity = list(zip(mask_names, weights_shapes, sparsity_levels, weights_percentages))
    raw_sparsity_statistics['sparsity_statistic_by_layer'] = []
    for mask_name, weights_shape, sparsity_level, weights_percentage in mask_sparsity:
        raw_sparsity_statistics['sparsity_statistic_by_layer'].append({
            'Name': mask_name,
            'Weight\'s Shape': weights_shape,
            'SR': sparsity_level,
            '% weights': weights_percentage
        })
    return raw_sparsity_statistics
def raw_statistics(self) -> Dict[str, object]:
    """Build a raw filter-pruning statistics dictionary for the model.

    Returns the overall pruning rate plus a per-layer breakdown
    (mask name, weight shape, per-filter mask shape, pruning rate) under
    ``'pruning_statistic_by_layer'``.
    """
    per_layer_rates = []
    names = []
    full_shapes = []
    reduced_shapes = []

    for layer in collect_wrapped_layers(self._model):
        for attr_name, attr_ops in layer.weights_attr_ops.items():
            for op_name in attr_ops:
                if op_name not in self._op_names:
                    continue
                mask = layer.ops_weights[op_name]['mask']
                names.append(mask.name)
                full_shapes.append(list(mask.shape))
                # Collapse every axis except the filter axis so one value
                # survives per filter.
                axes = list(range(len(mask.shape)))
                filter_axis = get_filter_axis(layer, attr_name)
                if filter_axis == -1:
                    filter_axis = axes[-1]
                axes.remove(filter_axis)
                per_filter_mask = tf.reduce_max(tf.cast(mask, tf.int32),
                                                axis=axes, keepdims=True)
                reduced_shapes.append(list(per_filter_mask.shape))
                filters_total = get_filters_num(layer)
                filters_pruned = filters_total - tf.reduce_sum(per_filter_mask)
                per_layer_rates.append(filters_pruned / filters_total)

    stats = {'pruning_rate': self.pruning_rate}
    per_layer_rates = tf.keras.backend.batch_get_value(per_layer_rates)
    stats['pruning_statistic_by_layer'] = [
        {
            'Name': name,
            'Weight\'s Shape': full_shape,
            'Mask Shape': reduced_shape,
            'PR': rate
        }
        for name, full_shape, reduced_shape, rate in zip(
            names, full_shapes, reduced_shapes, per_layer_rates)
    ]
    return stats
def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float):
    """Set binary masks per group: each filter group is thresholded on its
    own cumulative importance so every group is pruned at ``pruning_rate``.

    :param pruning_rate: Proportion (0..1) of filters to prune in each group.
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    wrapped_layers = collect_wrapped_layers(self._model)

    # 0. Drop any previously propagated masks from the NNCF graph.
    for graph_node in self._original_graph.topological_sort():
        graph_node.data.pop('output_mask', None)

    # 1. Per-group mask computation.
    for cluster in self._pruned_layer_groups_info.get_all_clusters():
        # a. Cumulative importance over all filters in the cluster.
        group_importance = \
            self._calculate_filters_importance_in_group(cluster, wrapped_layers)
        total_filters = len(group_importance)

        # b. Threshold at the rounded number of elements to prune,
        #    clamped so at least one filter always survives.
        prune_count = get_rounded_pruned_element_number(
            group_importance.shape[0], pruning_rate)
        threshold_pos = min(prune_count, total_filters - 1)
        threshold = sorted(group_importance)[threshold_pos]

        # c. Write the resulting mask onto every node of the cluster.
        group_mask = calculate_binary_mask(group_importance, threshold)
        for cluster_node in cluster.nodes:
            graph_node = self._original_graph.get_node_by_id(
                cluster_node.nncf_node_id)
            graph_node.data['output_mask'] = group_mask

    # 2. Propagate the masks through the rest of the graph.
    propagator = MaskPropagationAlgorithm(self._original_graph,
                                          TF_PRUNING_OPERATOR_METATYPES)
    propagator.mask_propagation()

    # 3. Push the propagated masks into the model operations.
    sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        matches = [n for n in sorted_nodes
                   if layer.layer.name == get_original_name(n.node_name)]
        graph_node = matches[0]
        if graph_node.data['output_mask'] is not None:
            self._set_operation_masks([layer], graph_node.data['output_mask'])
def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
        self, target_flops_pruning_rate: float):
    """
    Prunes least important filters one-by-one until target FLOPs pruning
    rate is achieved. Filters are sorted by filter importance score.

    :param target_flops_pruning_rate: Target proportion (0..1) of FLOPs
        to remove from the model.
    :raises RuntimeError: If the target FLOPs level cannot be reached with
        the available per-group pruning quotas.
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    target_flops = self.full_flops * (1 - target_flops_pruning_rate)
    wrapped_layers = collect_wrapped_layers(self._model)
    # FIX: was `masks = []` — indexing an empty list by group.id raised
    # IndexError; a dict keyed by group id is required.
    masks = {}

    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        nncf_node = [
            n for n in nncf_sorted_nodes
            if layer.layer.name == get_original_name(n.node_name)
        ][0]
        nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer))

    # 1. Calculate importances for all groups of filters. Initialize masks.
    filter_importances = []
    group_indexes = []
    filter_indexes = []
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = \
            self._calculate_filters_importance_in_group(group, wrapped_layers)
        filter_importances.extend(cumulative_filters_importance)
        filters_num = len(cumulative_filters_importance)
        group_indexes.extend([group.id] * filters_num)
        filter_indexes.extend(range(filters_num))
        masks[group.id] = tf.ones(filters_num)

    # 2. Zero out filters, least important first, until FLOPs target is met.
    tmp_in_channels = self._layers_in_channels.copy()
    tmp_out_channels = self._layers_out_channels.copy()
    sorted_importances = sorted(zip(filter_importances, group_indexes,
                                    filter_indexes),
                                key=lambda x: x[0])
    for _, group_id, filter_index in sorted_importances:
        if self._pruning_quotas[group_id] == 0:
            continue
        # FIX: tf.Tensor does not support item assignment
        # (`masks[group_id][filter_index] = 0` raised TypeError); rebuild
        # the mask with a scatter update instead.
        masks[group_id] = tf.tensor_scatter_nd_update(
            masks[group_id], [[filter_index]], [0.])
        self._pruning_quotas[group_id] -= 1

        # Update input/output shapes of pruned nodes
        group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
        for node in group.nodes:
            tmp_out_channels[node.layer_name] -= 1
        for node_name in self._next_nodes[group_id]:
            tmp_in_channels[node_name] -= 1

        flops = sum(
            count_flops_for_nodes(self._original_graph,
                                  self._layers_in_shapes,
                                  self._layers_out_shapes,
                                  input_channels=tmp_in_channels,
                                  output_channels=tmp_out_channels,
                                  conv_op_types=GENERAL_CONV_LAYERS,
                                  linear_op_types=LINEAR_LAYERS).values())
        if flops <= target_flops:
            # 3. Add masks to the graph and propagate them
            for group in self._pruned_layer_groups_info.get_all_clusters():
                for node in group.nodes:
                    nncf_node = self._original_graph.get_node_by_id(
                        node.nncf_node_id)
                    nncf_node.data['output_mask'] = masks[group.id]
            mask_propagator = MaskPropagationAlgorithm(
                self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
            mask_propagator.mask_propagation()

            # 4. Set binary masks to the model
            self.current_flops = flops
            nncf_sorted_nodes = self._original_graph.topological_sort()
            for layer in wrapped_layers:
                nncf_node = [
                    n for n in nncf_sorted_nodes
                    if layer.layer.name == get_original_name(n.node_name)
                ][0]
                if nncf_node.data['output_mask'] is not None:
                    self._set_operation_masks(
                        [layer], nncf_node.data['output_mask'])
            return
    raise RuntimeError(
        f'Unable to prune model to required flops pruning rate:'
        f' {target_flops_pruning_rate}')
def raw_statistics(self):
    """Build a raw RB-sparsity statistics dictionary for the model.

    Returns overall sparsity rates, the mean sparsification probability,
    the loss target rate, and a per-layer breakdown under
    ``'sparsity_statistic_by_layer'``.
    """
    raw_sparsity_statistics = {}
    sparsity_levels = []
    mask_names = []
    weights_shapes = []
    weights_numbers = []
    # Running sum of sigmoid(mask) over all elements — mask appears to hold
    # trainable logits rather than hard 0/1 values (TODO confirm against the
    # RB sparsifier implementation).
    sparse_prob_sum = tf.constant(0.)
    total_weights_number = tf.constant(0)
    total_sparsified_weights_number = tf.constant(0)
    wrapped_layers = collect_wrapped_layers(self._model)
    for wrapped_layer in wrapped_layers:
        for ops in wrapped_layer.weights_attr_ops.values():
            for op_name in ops:
                if op_name in self._op_names:
                    mask = wrapped_layer.ops_weights[op_name]['mask']
                    # Count of elements kept (non-zero after binarization).
                    sw_loss = tf.reduce_sum(binary_mask(mask))
                    weights_number = tf.size(mask)
                    sparsified_weights_number = weights_number - tf.cast(
                        sw_loss, tf.int32)
                    mask_names.append(wrapped_layer.name + '_rb_mask')
                    weights_shapes.append(list(mask.shape))
                    weights_numbers.append(weights_number)
                    sparsity_levels.append(sparsified_weights_number /
                                           weights_number)
                    sparse_prob_sum += tf.math.reduce_sum(
                        tf.math.sigmoid(mask))
                    total_weights_number += weights_number
                    total_sparsified_weights_number += sparsified_weights_number
    sparsity_rate_for_sparsified_modules = (
        total_sparsified_weights_number / total_weights_number).numpy()
    # Subtract mask variable sizes from the total parameter count —
    # presumably because mask variables are registered in model.weights
    # (TODO confirm).
    model_weights_number = count_params(
        self._model.weights) - total_weights_number
    sparsity_rate_for_model = (total_sparsified_weights_number /
                               model_weights_number).numpy()
    mean_sparse_prob = (sparse_prob_sum /
                        tf.cast(total_weights_number, tf.float32)).numpy()
    raw_sparsity_statistics.update({
        'sparsity_rate_for_sparsified_modules': sparsity_rate_for_sparsified_modules,
        'sparsity_rate_for_model': sparsity_rate_for_model,
        'mean_sparse_prob': mean_sparse_prob,
        'target_sparsity_rate': self.loss.target_sparsity_rate,
    })
    # Evaluate all symbolic per-layer ratios in one batched call.
    sparsity_levels = tf.keras.backend.batch_get_value(sparsity_levels)
    weights_percentages = [
        weights_number / total_weights_number * 100
        for weights_number in weights_numbers
    ]
    weights_percentages = tf.keras.backend.batch_get_value(
        weights_percentages)
    mask_sparsity = list(
        zip(mask_names, weights_shapes, sparsity_levels, weights_percentages))
    raw_sparsity_statistics['sparsity_statistic_by_layer'] = []
    for mask_name, weights_shape, sparsity_level, weights_percentage in mask_sparsity:
        raw_sparsity_statistics['sparsity_statistic_by_layer'].append({
            'Name': mask_name,
            'Weight\'s Shape': weights_shape,
            'SR': sparsity_level,
            '% weights': weights_percentage
        })
    return raw_sparsity_statistics