def test_get_rounded_pruned_element_number(total, sparsity_rate, multiple_of, ref):
    """Check get_rounded_pruned_element_number against a reference value.

    When `multiple_of` is given, also verify that the count of elements left
    after pruning is divisible by it.
    """
    if multiple_of is None:
        result = get_rounded_pruned_element_number(total, sparsity_rate)
    else:
        result = get_rounded_pruned_element_number(total, sparsity_rate, multiple_of)

    assert result == ref
    if multiple_of is not None:
        assert (total - result) % multiple_of == 0
def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate):
    """
    Prune all prunable modules in model with pruning_rate rate and returns flops of pruned model.

    :param pruning_rate: proportion of zero filters in all modules
    :return: flops number in pruned model
    """
    in_channels = self.modules_in_channels.copy()
    out_channels = self.modules_out_channels.copy()

    for cluster in self.pruned_module_groups_info.get_all_clusters():
        first_scope = cluster.nodes[0].module_scope
        # All nodes of one cluster must share the same output channel count
        assert all(out_channels[node.module_scope] == out_channels[first_scope]
                   for node in cluster.nodes)

        # Prune every node in the cluster by output channels
        current_out = self.modules_out_channels[first_scope]
        pruned_count = get_rounded_pruned_element_number(current_out, pruning_rate)
        remaining_out = current_out - pruned_count
        for node in cluster.nodes:
            out_channels[node.module_scope] = remaining_out

        # Shrink in_channels of every node that consumes this cluster's output
        for node_id in self.next_nodes[cluster.id]:
            in_channels[node_id] = remaining_out

    return self._calculate_flops_in_pruned_model(in_channels, out_channels)
def _set_binary_masks_for_filters(self, pruning_rate):
    """Recompute binary pruning masks for every pruned module group.

    For each cluster, per-filter importances are accumulated across all of its
    modules, a threshold is chosen so that roughly `pruning_rate` of the
    filters fall below it, and the resulting mask is shared by every module in
    the cluster. Finally the model's FLOPs are re-estimated under the masks.
    """
    nncf_logger.debug("Setting new binary masks for pruned modules.")
    with torch.no_grad():
        for cluster in self.pruned_module_groups_info.get_all_clusters():
            filters_num = torch.tensor([get_filters_num(minfo.module) for minfo in cluster.nodes])
            # Every module in the cluster must have the same number of filters
            assert torch.all(filters_num == filters_num[0])
            device = cluster.nodes[0].module.weight.device

            # 1. Accumulate each filter's importance over the whole cluster
            cumulative_importance = torch.zeros(filters_num[0]).to(device)
            for minfo in cluster.nodes:
                cumulative_importance += self.filter_importance(
                    minfo.module.weight, minfo.module.target_weight_dim_for_compression)

            # 2. Pick the threshold; the index is clamped so the lookup stays
            # in bounds even when the rounded sparse count equals filters_num
            num_sparse = get_rounded_pruned_element_number(
                cumulative_importance.size(0), pruning_rate)
            threshold = sorted(cumulative_importance)[min(num_sparse, filters_num[0] - 1)]
            mask = calculate_binary_mask(cumulative_importance, threshold)

            # 3. Assign the shared mask to every module of the cluster
            for minfo in cluster.nodes:
                minfo.operand.binary_filter_pruning_mask = mask

        # Calculate actual flops with new masks
        self.current_flops = self._calculate_flops_pruned_model_by_masks()
def _set_binary_masks_for_filters(self):
    """Recompute the binary filter-pruning mask of every pruned module.

    For each module: compute per-filter importance, pick the threshold below
    which `self.pruning_rate` of the filters fall, and store the resulting
    binary mask on the module's pruning operand.
    """
    nncf_logger.debug("Setting new binary masks for pruned modules.")
    with torch.no_grad():
        for minfo in self.pruned_module_info:
            pruning_module = minfo.operand
            # 1. Calculate importance for all filters in the module's weight
            filters_importance = self.filter_importance(minfo.module.weight)
            filters_num = filters_importance.size(0)
            # 2. Calculate the threshold. Clamp the index so the lookup stays
            # in bounds (and at least one filter survives) when the rounded
            # sparse-element count equals the total filter count — consistent
            # with the group-wise mask-setting implementation, and fixing an
            # IndexError at pruning rates that round up to 100%.
            num_of_sparse_elems = get_rounded_pruned_element_number(filters_num, self.pruning_rate)
            threshold = sorted(filters_importance)[min(num_of_sparse_elems, filters_num - 1)]
            # 3. Set the binary mask for the module's filters
            pruning_module.binary_filter_pruning_mask = calculate_binary_mask(filters_importance, threshold)