Пример #1
0
def test_get_rounded_pruned_element_number(total, sparsity_rate, multiple_of, ref):
    """
    Check the pruned-element count against ``ref``.

    When ``multiple_of`` is given, also check that the number of elements
    left after pruning is divisible by it.
    """
    # Forward ``multiple_of`` only when the case specifies one, so the
    # two-argument form of the call is exercised otherwise.
    call_args = [total, sparsity_rate]
    if multiple_of is not None:
        call_args.append(multiple_of)
    result = get_rounded_pruned_element_number(*call_args)
    assert ref == result

    if multiple_of is None:
        return
    assert (total - result) % multiple_of == 0
Пример #2
0
    def _calculate_flops_in_uniformly_pruned_model(self, pruning_rate):
        """
        Estimate model FLOPs as if every prunable cluster were pruned at the same rate.

        Only copies of ``self.modules_in_channels`` / ``self.modules_out_channels``
        are written to here; the stored channel counts are left untouched.
        :param pruning_rate: proportion of zero filters in all modules
        :return: flops number in pruned model
        """
        pruned_in_channels = self.modules_in_channels.copy()
        pruned_out_channels = self.modules_out_channels.copy()

        for cluster in self.pruned_module_groups_info.get_all_clusters():
            # Modules of one cluster are pruned together, so they must agree
            # on their current number of output channels.
            cluster_widths = [pruned_out_channels[member.module_scope]
                              for member in cluster.nodes]
            assert all(cluster_widths[0] == width for width in cluster_widths)

            # Shrink the output channels of every module in the cluster.
            original_width = self.modules_out_channels[
                cluster.nodes[0].module_scope]
            remaining_width = original_width - get_rounded_pruned_element_number(
                original_width, pruning_rate)
            pruned_out_channels.update(
                (member.module_scope, remaining_width) for member in cluster.nodes)

            # Modules consuming this cluster's output lose the same input channels.
            for consumer in self.next_nodes[cluster.id]:
                pruned_in_channels[consumer] = remaining_width

        return self._calculate_flops_in_pruned_model(pruned_in_channels,
                                                     pruned_out_channels)
Пример #3
0
    def _set_binary_masks_for_filters(self, pruning_rate):
        """
        Recompute binary filter masks for every pruned cluster and refresh ``self.current_flops``.

        Filter importances are summed over all modules of a cluster, and the
        resulting single mask object is assigned to every module of that cluster.
        :param pruning_rate: proportion of filters to prune in each cluster
        """
        nncf_logger.debug("Setting new binary masks for pruned modules.")

        with torch.no_grad():
            for group in self.pruned_module_groups_info.get_all_clusters():
                filters_num = torch.tensor(
                    [get_filters_num(minfo.module) for minfo in group.nodes])
                assert torch.all(filters_num == filters_num[0])
                num_filters = int(filters_num[0])
                device = group.nodes[0].module.weight.device

                # 1. Calculate cumulative importance for all filters in group.
                # Allocate directly on the target device rather than on CPU + copy.
                cumulative_filters_importance = torch.zeros(num_filters, device=device)
                for minfo in group.nodes:
                    cumulative_filters_importance += self.filter_importance(
                        minfo.module.weight,
                        minfo.module.target_weight_dim_for_compression)

                # 2. Calculate threshold: the importance at position
                # ``num_of_sparse_elems`` in ascending order, clamped to the last
                # filter so the index stays valid when num_of_sparse_elems reaches
                # the filter count. torch.sort does this in one native call instead
                # of Python-level comparisons of 0-d tensors.
                num_of_sparse_elems = get_rounded_pruned_element_number(
                    cumulative_filters_importance.size(0), pruning_rate)
                sorted_importance, _ = torch.sort(cumulative_filters_importance)
                threshold = sorted_importance[min(num_of_sparse_elems, num_filters - 1)]
                # NOTE(review): presumably filters below ``threshold`` are masked
                # out — confirm in calculate_binary_mask.
                mask = calculate_binary_mask(cumulative_filters_importance,
                                             threshold)

                # 3. Set binary masks for filters
                for minfo in group.nodes:
                    minfo.operand.binary_filter_pruning_mask = mask

        # Calculate actual flops with new masks
        self.current_flops = self._calculate_flops_pruned_model_by_masks()
Пример #4
0
    def _set_binary_masks_for_filters(self):
        """
        Recompute the binary filter mask of every pruned module independently.

        For each module, filters are ranked by ``self.filter_importance`` and the
        threshold is the importance found at the ``self.pruning_rate``-derived
        position of the ascending order.
        """
        nncf_logger.debug("Setting new binary masks for pruned modules.")

        with torch.no_grad():
            for minfo in self.pruned_module_info:
                # 1. Calculate importance for all filters of this weight
                filters_importance = self.filter_importance(minfo.module.weight)
                num_filters = filters_importance.size(0)

                # 2. Calculate the threshold for this weight. The index is clamped
                # to the last filter: without the clamp, indexing raises IndexError
                # whenever num_of_sparse_elems is not smaller than num_filters.
                num_of_sparse_elems = get_rounded_pruned_element_number(num_filters,
                                                                        self.pruning_rate)
                sorted_importance, _ = torch.sort(filters_importance)
                threshold = sorted_importance[min(num_of_sparse_elems, num_filters - 1)]

                # 3. Set binary mask for the filters of this module
                minfo.operand.binary_filter_pruning_mask = calculate_binary_mask(filters_importance, threshold)