Пример #1
0
    def _set_binary_masks_for_filters(self, pruning_rate):
        """Recompute and assign per-group binary pruning masks.

        For every cluster of jointly pruned modules, per-filter importances
        are accumulated across the cluster, a threshold is derived from the
        requested pruning rate, and the resulting binary mask is shared by
        all modules in the cluster. The model FLOPs are then re-estimated
        with the new masks.

        :param pruning_rate: fraction of filters to prune in each group.
        """
        nncf_logger.debug("Setting new binary masks for pruned modules.")

        with torch.no_grad():
            for cluster in self.pruned_module_groups_info.get_all_clusters():
                counts = torch.tensor(
                    [get_filters_num(node.module) for node in cluster.nodes])
                # All modules in one cluster must share the filter count.
                assert torch.all(counts == counts[0])
                device = cluster.nodes[0].module.weight.device

                # 1. Accumulate filter importance over the whole cluster.
                group_importance = torch.zeros(counts[0]).to(device)
                for node in cluster.nodes:
                    group_importance += self.filter_importance(
                        node.module.weight,
                        node.module.target_weight_dim_for_compression)

                # 2. Threshold from the pruning rate; the index is clamped
                # so at least one filter always survives.
                sparse_count = get_rounded_pruned_element_number(
                    group_importance.size(0), pruning_rate)
                cutoff_idx = min(sparse_count, counts[0] - 1)
                threshold = sorted(group_importance)[cutoff_idx]
                mask = calculate_binary_mask(group_importance, threshold)

                # 3. Every module of the cluster shares the same mask.
                for node in cluster.nodes:
                    node.operand.binary_filter_pruning_mask = mask

        # Re-estimate actual FLOPs with the freshly assigned masks.
        self.current_flops = self._calculate_flops_pruned_model_by_masks()
Пример #2
0
    def _set_binary_masks_for_filters(self):
        """Recompute the binary pruning mask of every pruned module.

        For each module: compute per-filter importance, derive the importance
        threshold matching ``self.pruning_rate``, and store the resulting
        binary mask on the module's pruning operand.
        """
        nncf_logger.debug("Setting new binary masks for pruned modules.")

        with torch.no_grad():
            for minfo in self.pruned_module_info:
                pruning_module = minfo.operand
                # 1. Per-filter importance of this module's weight.
                filters_importance = self.filter_importance(minfo.module.weight)
                filters_num = filters_importance.size(0)
                # 2. Threshold separating pruned from kept filters.
                num_of_sparse_elems = get_rounded_pruned_element_number(filters_num,
                                                                        self.pruning_rate)
                # Clamp the index: a pruning rate that rounds to the full
                # filter count would otherwise index one past the end.
                threshold = sorted(filters_importance)[min(num_of_sparse_elems, filters_num - 1)]
                # 3. Binary mask: 1 keeps a filter, 0 prunes it.
                pruning_module.binary_filter_pruning_mask = calculate_binary_mask(filters_importance, threshold)
Пример #3
0
    def _set_binary_masks_for_all_pruned_modules(self):
        """Set binary masks for all pruned modules with one global threshold.

        Filter importances of the normalized weights are computed per module,
        then a single threshold over the concatenated importances determines
        the binary mask of every module.
        """
        nncf_logger.debug("Setting new binary masks for all pruned modules together.")

        # 1. Per-module filter importances of the normalized weights.
        filter_importances = []
        for minfo in self.pruned_module_info:
            normalized_weight = self.weights_normalizer(minfo.module.weight)
            filter_importances.append(self.filter_importance(normalized_weight))

        # 2. One global threshold across all modules; the index is clamped so
        # a pruning rate of 1.0 cannot index one past the end of the tensor.
        importances = torch.cat(filter_importances)
        threshold_idx = min(int(self.pruning_rate * importances.size(0)),
                            importances.size(0) - 1)
        threshold = sorted(importances)[threshold_idx]

        # 3. Each module's mask compares its own importances to the threshold.
        for i, minfo in enumerate(self.pruned_module_info):
            minfo.operand.binary_filter_pruning_mask = calculate_binary_mask(filter_importances[i], threshold)
Пример #4
0
    def _set_binary_masks_for_all_pruned_modules(self, pruning_rate):
        """Set binary masks for all pruned groups with one global threshold.

        Per-group filter importances of the normalized weights are
        accumulated, a single threshold over the concatenated importances is
        derived from ``pruning_rate``, and every module in a group receives
        the same mask. The model FLOPs are re-estimated afterwards.

        :param pruning_rate: global fraction of filters to prune.
        """
        nncf_logger.debug(
            "Setting new binary masks for all pruned modules together.")
        filter_importances = []
        # 1. Calculate importances for all groups of filters
        for group in self.pruned_module_groups_info.get_all_clusters():
            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in group.nodes])
            # Every module in a group must have the same number of filters.
            assert torch.all(filters_num == filters_num[0])
            device = group.nodes[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # Accumulate importance of the normalized weights over the group.
            for minfo in group.nodes:
                normalized_weight = self.weights_normalizer(
                    minfo.module.weight)
                cumulative_filters_importance += self.filter_importance(
                    normalized_weight,
                    minfo.module.target_weight_dim_for_compression)

            filter_importances.append(cumulative_filters_importance)

        # 2. One threshold for all weights; the index is clamped so a
        # pruning rate of 1.0 cannot index one past the end of the tensor.
        importances = torch.cat(filter_importances)
        threshold_idx = min(int(pruning_rate * importances.size(0)),
                            importances.size(0) - 1)
        threshold = sorted(importances)[threshold_idx]

        # 3. Set binary masks for filters in groups
        for i, group in enumerate(
                self.pruned_module_groups_info.get_all_clusters()):
            mask = calculate_binary_mask(filter_importances[i], threshold)
            for minfo in group.nodes:
                minfo.operand.binary_filter_pruning_mask = mask

        # Calculate actual flops with new masks
        self.current_flops = self._calculate_flops_pruned_model_by_masks()