def _calculate_flops_and_weights_in_uniformly_pruned_model(
        self, pruning_level: float) -> Tuple[int, int]:
    """
    Estimates flops and weights count of the model for the case when every
    pruning group is pruned with the same `pruning_level`.

    The per-node input/output channel numbers for that hypothetical state
    are obtained from `calculate_in_out_channels_in_uniformly_pruned_model`
    and then passed to `count_flops_and_weights` together with the original
    graph and the stored input/output shapes.

    :param pruning_level: proportion of zero filters in all modules
    :return: the value produced by `count_flops_and_weights` for the
        uniformly pruned channel configuration (flops and number of weights)
    """
    # Channel numbers each node would have under uniform pruning.
    uniform_channels = calculate_in_out_channels_in_uniformly_pruned_model(
        pruning_groups=self.pruned_module_groups_info.get_all_clusters(),
        pruning_level=pruning_level,
        full_input_channels=self._modules_in_channels,
        full_output_channels=self._modules_out_channels,
        pruning_groups_next_nodes=self.next_nodes)
    pruned_in_channels, pruned_out_channels = uniform_channels

    original_graph = self._model.get_original_graph()
    return count_flops_and_weights(
        original_graph,
        self._modules_in_shapes,
        self._modules_out_shapes,
        input_channels=pruned_in_channels,
        output_channels=pruned_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
def _calculate_flops_and_weights_in_uniformly_pruned_model(
        self, pruning_level):
    """
    Estimates flops and weights count of the model for the case when every
    pruning group is pruned with the same `pruning_level`.

    Channel numbers for that hypothetical state come from
    `calculate_in_out_channels_in_uniformly_pruned_model`; they are then fed
    to `count_flops_and_weights` along with the original graph and the
    stored layer input/output shapes.

    :param pruning_level: pruning level applied to all pruning groups
    :return: the value produced by `count_flops_and_weights` for the
        uniformly pruned channel configuration
    """
    # Channel numbers each layer would have under uniform pruning.
    uniform_channels = calculate_in_out_channels_in_uniformly_pruned_model(
        pruning_groups=self._pruned_layer_groups_info.get_all_clusters(),
        pruning_level=pruning_level,
        full_input_channels=self._layers_in_channels,
        full_output_channels=self._layers_out_channels,
        pruning_groups_next_nodes=self._next_nodes)
    pruned_in_channels, pruned_out_channels = uniform_channels

    return count_flops_and_weights(
        self._original_graph,
        self._layers_in_shapes,
        self._layers_out_shapes,
        input_channels=pruned_in_channels,
        output_channels=pruned_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
def _update_benchmark_statistics(self):
    """
    Refreshes `current_filters_num`, `current_flops` and
    `current_params_num` on the controller.

    Channel numbers are derived from the masks returned by
    `_collect_pruning_masks()`; the statistics are then recomputed on the
    original graph with `count_filters_num` and `count_flops_and_weights`.
    """
    # Channel numbers implied by the currently collected pruning masks.
    masked_channels = calculate_in_out_channels_by_masks(
        pruning_groups=self.pruned_module_groups_info.get_all_clusters(),
        masks=self._collect_pruning_masks(),
        tensor_processor=PTNNCFCollectorTensorProcessor,
        full_input_channels=self._modules_in_channels,
        full_output_channels=self._modules_out_channels,
        pruning_groups_next_nodes=self.next_nodes)
    masked_in_channels, masked_out_channels = masked_channels

    self.current_filters_num = count_filters_num(
        self._model.get_original_graph(),
        op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        output_channels=masked_out_channels)

    flops_and_params = count_flops_and_weights(
        self._model.get_original_graph(),
        self._modules_in_shapes,
        self._modules_out_shapes,
        input_channels=masked_in_channels,
        output_channels=masked_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    self.current_flops, self.current_params_num = flops_and_params
def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops,
                              ref_params_num):
    """
    Checks the flops and parameters number reported by the pruning algorithm
    against reference values, and verifies that the same numbers are
    obtained when they are recomputed from the collected pruning masks.

    :param all_weights: whether mask will be calculated for all weights in
        common or not
    :param pruning_flops_target: prune model by flops, if None then by
        number of channels
    :param ref_flops: reference flops of the pruned model
    :param ref_params_num: reference number of parameters of the pruned model
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    compression_section = config['compression']
    pruning_params = compression_section['params']
    pruning_params['all_weights'] = all_weights
    pruning_params['prune_first_conv'] = True
    compression_section['pruning_init'] = 0.5
    if pruning_flops_target:
        pruning_params['pruning_flops_target'] = pruning_flops_target

    _, pruning_algo, _ = create_pruning_algo_with_config(config)

    # Values stored on the controller after creation.
    assert pruning_algo.current_flops == ref_flops
    assert pruning_algo.current_params_num == ref_params_num

    # Recompute the same statistics from the masks.
    # pylint:disable=protected-access
    masked_channels = calculate_in_out_channels_by_masks(
        pruning_algo.pruned_module_groups_info.get_all_clusters(),
        masks=pruning_algo._collect_pruning_masks(),
        tensor_processor=PTNNCFCollectorTensorProcessor,
        full_input_channels=pruning_algo._modules_in_channels,
        full_output_channels=pruning_algo._modules_out_channels,
        pruning_groups_next_nodes=pruning_algo.next_nodes)
    masked_in_channels, masked_out_channels = masked_channels

    recomputed = count_flops_and_weights(
        pruning_algo._model.get_original_graph(),
        pruning_algo._modules_in_shapes,
        pruning_algo._modules_out_shapes,
        input_channels=masked_in_channels,
        output_channels=masked_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    cur_flops, cur_params_num = recomputed
    assert (cur_flops, cur_params_num) == (ref_flops, ref_params_num)
def test_flops_calulation_for_spec_layers(model_fn, all_weights,
                                          pruning_flops_target,
                                          ref_full_flops, ref_current_flops,
                                          ref_full_params,
                                          ref_current_params):
    """
    Checks full and current flops / parameters numbers reported by the
    filter pruning controller for the model built by `model_fn`, and
    verifies that the current values match a recomputation from the
    collected pruning masks.

    :param model_fn: callable that builds the model from an input shape
    :param all_weights: value for the `all_weights` pruning parameter
    :param pruning_flops_target: value used both as `pruning_init` and as
        the `pruning_flops_target` pruning parameter
    :param ref_full_flops: reference flops of the unpruned model
    :param ref_current_flops: reference flops of the pruned model
    :param ref_full_params: reference parameters number of the unpruned model
    :param ref_current_params: reference parameters number of the pruned model
    """
    config = get_basic_pruning_config(8)
    compression_section = config['compression']
    compression_section['algorithm'] = 'filter_pruning'
    compression_section['pruning_init'] = pruning_flops_target
    pruning_params = compression_section['params']
    pruning_params['pruning_flops_target'] = pruning_flops_target
    pruning_params['prune_first_conv'] = True
    pruning_params['all_weights'] = all_weights

    input_shape = [1, 8, 8, 1]
    model = model_fn(input_shape)
    model.compile()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)

    # Values stored on the controller after creation.
    assert compression_ctrl.full_flops == ref_full_flops
    assert compression_ctrl.full_params_num == ref_full_params
    assert compression_ctrl.current_flops == ref_current_flops
    assert compression_ctrl.current_params_num == ref_current_params

    # Recompute the current statistics from the masks.
    # pylint:disable=protected-access
    masked_channels = calculate_in_out_channels_by_masks(
        pruning_groups=compression_ctrl._pruned_layer_groups_info.get_all_clusters(),
        masks=compression_ctrl._collect_pruning_masks(),
        tensor_processor=TFNNCFCollectorTensorProcessor,
        full_input_channels=compression_ctrl._layers_in_channels,
        full_output_channels=compression_ctrl._layers_out_channels,
        pruning_groups_next_nodes=compression_ctrl._next_nodes)
    masked_in_channels, masked_out_channels = masked_channels

    recomputed = count_flops_and_weights(
        compression_ctrl._original_graph,
        compression_ctrl._layers_in_shapes,
        compression_ctrl._layers_out_shapes,
        input_channels=masked_in_channels,
        output_channels=masked_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    cur_flops, cur_params_num = recomputed
    assert (cur_flops, cur_params_num) == (ref_current_flops,
                                           ref_current_params)
def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
        self, target_flops_pruning_level: float) -> None:
    """
    Ranks all prunable filters of the network by importance and zeroes the
    least important ones, one at a time, re-counting flops after each
    removal, until the flops count drops below the target.

    On success `current_flops` and `current_params_num` are updated with
    the values computed for the final channel configuration.

    :param target_flops_pruning_level: target proportion of flops removed
        from the model
    :raises RuntimeError: if every ranked filter has been processed and the
        flops count never dropped below the target
    """
    target_flops = self.full_flops * (1 - target_flops_pruning_level)

    # 1. Reset the mask of every prunable module to all ones
    for minfo in self.pruned_module_groups_info.get_all_nodes():
        ones_mask = torch.ones(get_filters_num(minfo.module))
        self.set_mask(minfo, ones_mask.to(minfo.module.weight.device))

    # 2. Collect cumulative filter importances for every prunable group,
    # together with the cluster id and in-cluster index of each filter
    importance_chunks = []
    cluster_id_chunks = []
    filter_id_chunks = []
    for cluster in self.pruned_module_groups_info.get_all_clusters():
        filters_num = torch.tensor(
            [get_filters_num(minfo.module) for minfo in cluster.elements])
        # All modules of one cluster must have the same number of filters
        assert torch.all(filters_num == filters_num[0])
        device = cluster.elements[0].module.weight.device
        cluster_importance = torch.zeros(filters_num[0]).to(device)

        # Sum the (linearly rescaled) importances over the cluster modules
        for minfo in cluster.elements:
            weight = minfo.module.weight
            if self.normalize_weights:
                weight = self.weights_normalizer(weight)
            filters_importance = self.filter_importance(
                weight, minfo.module.target_weight_dim_for_compression)
            coeffs = self.ranking_coeffs[minfo.node_name]
            cluster_importance += coeffs[0] * filters_importance + coeffs[1]

        importance_chunks.append(cluster_importance)
        cluster_id_chunks.append(
            cluster.id * torch.ones_like(cluster_importance))
        filter_id_chunks.append(torch.arange(len(cluster_importance)))

    importances = torch.cat(importance_chunks)
    cluster_indexes = torch.cat(cluster_id_chunks)
    filter_indexes = torch.cat(filter_id_chunks)

    # 3. Walk over filters in ascending importance order and prune them
    # until the target flops value is reached
    ranked_filters = sorted(
        zip(importances, cluster_indexes, filter_indexes),
        key=lambda entry: entry[0])

    # Work on copies so that the stored full channel numbers and quotas
    # stay untouched
    tmp_in_channels = self._modules_in_channels.copy()
    tmp_out_channels = self._modules_out_channels.copy()
    tmp_pruning_quotas = self.pruning_quotas.copy()

    for _, raw_cluster_idx, raw_filter_idx in ranked_filters:
        cluster_idx = int(raw_cluster_idx)
        filter_idx = int(raw_filter_idx)

        # Skip filters of clusters whose pruning quota is exhausted
        if tmp_pruning_quotas[cluster_idx] <= 0:
            continue
        tmp_pruning_quotas[cluster_idx] -= 1

        cluster = self.pruned_module_groups_info.get_cluster_by_id(
            cluster_idx)
        for node in cluster.elements:
            tmp_out_channels[node.node_name] -= 1
            if node.is_depthwise:
                tmp_in_channels[node.node_name] -= 1
            node.operand.binary_filter_pruning_mask[filter_idx] = 0

        # Prune in channels in all next nodes
        for node_id in self.next_nodes[cluster.id]:
            tmp_in_channels[node_id] -= 1

        flops, params_num = count_flops_and_weights(
            self._model.get_original_graph(),
            self._modules_in_shapes,
            self._modules_out_shapes,
            input_channels=tmp_in_channels,
            output_channels=tmp_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)
        if flops < target_flops:
            self.current_flops = flops
            self.current_params_num = params_num
            return

    raise RuntimeError("Can't prune model to asked flops pruning level")
def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
        self, target_flops_pruning_level: float):
    """
    Prunes least important filters one-by-one until target FLOPs pruning
    level is achieved. Filters are sorted by filter importance score.

    On success the resulting group masks are written to the graph nodes,
    propagated with `MaskPropagationAlgorithm`, applied to the wrapped
    layers via `_set_operation_masks`, and `current_flops` /
    `current_params_num` are updated.

    :param target_flops_pruning_level: target proportion of flops removed
        from the model
    :raises RuntimeError: if all ranked filters were processed and the
        flops count never reached the target
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    target_flops = self.full_flops * (1 - target_flops_pruning_level)
    wrapped_layers = collect_wrapped_layers(self._model)
    masks = {}

    # Reset the output mask of every wrapped layer's graph node to all ones
    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        nncf_node = [
            n for n in nncf_sorted_nodes if layer.name == n.layer_name
        ][0]
        nncf_node.data['output_mask'] = TFNNCFTensor(
            tf.ones(get_filters_num(layer)))

    # 1. Calculate importances for all groups of filters. Initialize masks.
    filter_importances = []
    group_indexes = []
    filter_indexes = []
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = \
            self._calculate_filters_importance_in_group(group)
        filter_importances.extend(cumulative_filters_importance)
        filters_num = len(cumulative_filters_importance)
        group_indexes.extend([group.id] * filters_num)
        filter_indexes.extend(range(filters_num))
        masks[group.id] = tf.ones(filters_num)

    # 2. Prune filters in ascending importance order, re-counting flops
    # after each removal
    tmp_in_channels = self._layers_in_channels.copy()
    tmp_out_channels = self._layers_out_channels.copy()
    # Spend quotas from a copy: masks are rebuilt from all ones on every
    # call, so the stored per-group quotas must not be consumed here.
    tmp_pruning_quotas = self._pruning_quotas.copy()
    sorted_importances = sorted(
        zip(filter_importances, group_indexes, filter_indexes),
        key=lambda x: x[0])
    for _, group_id, filter_index in sorted_importances:
        # Skip filters of groups whose pruning quota is exhausted
        if tmp_pruning_quotas[group_id] == 0:
            continue
        masks[group_id] = tf.tensor_scatter_nd_update(
            masks[group_id], [[filter_index]], [0])
        tmp_pruning_quotas[group_id] -= 1

        # Update input/output shapes of pruned elements
        group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
        for node in group.elements:
            tmp_out_channels[node.node_name] -= 1
            if node.is_depthwise:
                tmp_in_channels[node.node_name] -= 1
        for node_name in self._next_nodes[group_id]:
            tmp_in_channels[node_name] -= 1

        flops, params_num = count_flops_and_weights(
            self._original_graph,
            self._layers_in_shapes,
            self._layers_out_shapes,
            input_channels=tmp_in_channels,
            output_channels=tmp_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)
        if flops <= target_flops:
            # 3. Add masks to the graph and propagate them
            for group in self._pruned_layer_groups_info.get_all_clusters():
                for node in group.elements:
                    nncf_node = self._original_graph.get_node_by_id(
                        node.nncf_node_id)
                    nncf_node.data['output_mask'] = TFNNCFTensor(
                        masks[group.id])

            mask_propagator = MaskPropagationAlgorithm(
                self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
                TFNNCFPruningTensorProcessor)
            mask_propagator.mask_propagation()

            # 4. Set binary masks to the model
            self.current_flops = flops
            self.current_params_num = params_num
            nncf_sorted_nodes = self._original_graph.topological_sort()
            for layer in wrapped_layers:
                nncf_node = [
                    n for n in nncf_sorted_nodes
                    if layer.name == n.layer_name
                ][0]
                if nncf_node.data['output_mask'] is not None:
                    self._set_operation_masks(
                        [layer], nncf_node.data['output_mask'].tensor)
            return

    raise RuntimeError(
        f'Unable to prune model to required flops pruning level:'
        f' {target_flops_pruning_level}')