Beispiel #1
0
    def _calculate_flops_and_weights_in_uniformly_pruned_model(
            self, pruning_level: float) -> Tuple[int, int]:
        """
        Estimate the cost of the model as if every pruning group were pruned at the same level.

        The per-node input/output channel counts of the uniformly pruned model are computed
        first; they are then passed, together with the original graph and the recorded
        input/output shapes, to ``count_flops_and_weights``.

        :param pruning_level: proportion of zero filters in all modules
        :return: pair of (flops, number of weights) of the pruned model
        """
        clusters = self.pruned_module_groups_info.get_all_clusters()
        pruned_in_channels, pruned_out_channels = calculate_in_out_channels_in_uniformly_pruned_model(
            pruning_groups=clusters,
            pruning_level=pruning_level,
            full_input_channels=self._modules_in_channels,
            full_output_channels=self._modules_out_channels,
            pruning_groups_next_nodes=self.next_nodes)

        # Only the channel counts change; graph and tensor shapes are those of the full model.
        original_graph = self._model.get_original_graph()
        flops_and_weights = count_flops_and_weights(
            original_graph,
            self._modules_in_shapes,
            self._modules_out_shapes,
            input_channels=pruned_in_channels,
            output_channels=pruned_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)
        return flops_and_weights
Beispiel #2
0
    def _calculate_flops_and_weights_in_uniformly_pruned_model(
            self, pruning_level):
        """
        Estimate the cost of the model as if every pruning group were pruned at the same level.

        :param pruning_level: pruning level applied uniformly to all pruning groups
        :return: the result of ``count_flops_and_weights`` evaluated with the channel counts
            of the uniformly pruned model
        """
        layer_groups = self._pruned_layer_groups_info.get_all_clusters()
        pruned_in_channels, pruned_out_channels = calculate_in_out_channels_in_uniformly_pruned_model(
            pruning_groups=layer_groups,
            pruning_level=pruning_level,
            full_input_channels=self._layers_in_channels,
            full_output_channels=self._layers_out_channels,
            pruning_groups_next_nodes=self._next_nodes)

        # Only the channel counts change; graph and tensor shapes are those of the full model.
        flops_and_weights = count_flops_and_weights(
            self._original_graph,
            self._layers_in_shapes,
            self._layers_out_shapes,
            input_channels=pruned_in_channels,
            output_channels=pruned_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)
        return flops_and_weights
Beispiel #3
0
    def _update_benchmark_statistics(self):
        """
        Refresh the cached statistics of the currently pruned model.

        The channel counts implied by the current pruning masks are computed and then used to
        update ``self.current_filters_num``, ``self.current_flops`` and
        ``self.current_params_num``.
        """
        clusters = self.pruned_module_groups_info.get_all_clusters()
        current_masks = self._collect_pruning_masks()
        masked_in_channels, masked_out_channels = calculate_in_out_channels_by_masks(
            pruning_groups=clusters,
            masks=current_masks,
            tensor_processor=PTNNCFCollectorTensorProcessor,
            full_input_channels=self._modules_in_channels,
            full_output_channels=self._modules_out_channels,
            pruning_groups_next_nodes=self.next_nodes)

        # Number of filters is derived from output channels of the convolution-like layers.
        self.current_filters_num = count_filters_num(
            self._model.get_original_graph(),
            op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            output_channels=masked_out_channels)

        flops, params_num = count_flops_and_weights(
            self._model.get_original_graph(),
            self._modules_in_shapes,
            self._modules_out_shapes,
            input_channels=masked_in_channels,
            output_channels=masked_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)
        self.current_flops = flops
        self.current_params_num = params_num
Beispiel #4
0
def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops,
                              ref_params_num):
    """
    Check the FLOPs and parameter counts reported by the pruning algorithm.

    The values stored on the algorithm after creation are compared with the references, and
    then the same values are recomputed from the current pruning masks and compared again.

    :param all_weights: value written to the ``all_weights`` pruning parameter
    :param pruning_flops_target: value for the ``pruning_flops_target`` parameter; the
        parameter is left unset when this is falsy
    :param ref_flops: reference FLOPs of the pruned model
    :param ref_params_num: reference number of parameters of the pruned model
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    compression_cfg = config['compression']
    algo_params = compression_cfg['params']
    algo_params['all_weights'] = all_weights
    algo_params['prune_first_conv'] = True
    compression_cfg['pruning_init'] = 0.5
    if pruning_flops_target:
        algo_params['pruning_flops_target'] = pruning_flops_target

    _, pruning_algo, _ = create_pruning_algo_with_config(config)

    # Statistics cached by the algorithm itself.
    assert pruning_algo.current_flops == ref_flops
    assert pruning_algo.current_params_num == ref_params_num

    # Statistics recomputed independently from the pruning masks.
    # pylint:disable=protected-access
    current_masks = pruning_algo._collect_pruning_masks()
    masked_in_channels, masked_out_channels = calculate_in_out_channels_by_masks(
        pruning_algo.pruned_module_groups_info.get_all_clusters(),
        masks=current_masks,
        tensor_processor=PTNNCFCollectorTensorProcessor,
        full_input_channels=pruning_algo._modules_in_channels,
        full_output_channels=pruning_algo._modules_out_channels,
        pruning_groups_next_nodes=pruning_algo.next_nodes)

    recomputed = count_flops_and_weights(
        pruning_algo._model.get_original_graph(),
        pruning_algo._modules_in_shapes,
        pruning_algo._modules_out_shapes,
        input_channels=masked_in_channels,
        output_channels=masked_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    cur_flops, cur_params_num = recomputed
    assert (cur_flops, cur_params_num) == (ref_flops, ref_params_num)
def test_flops_calulation_for_spec_layers(model_fn, all_weights, pruning_flops_target,
                                          ref_full_flops, ref_current_flops,
                                          ref_full_params, ref_current_params):
    """
    Check full and current FLOPs/parameter counts of a filter-pruned model.

    The statistics stored on the compression controller are compared with the references, and
    the current statistics are additionally recomputed from the pruning masks.

    :param model_fn: callable building the model from an input shape
    :param all_weights: value written to the ``all_weights`` pruning parameter
    :param pruning_flops_target: used both as ``pruning_init`` and ``pruning_flops_target``
    :param ref_full_flops: reference FLOPs of the unpruned model
    :param ref_current_flops: reference FLOPs of the pruned model
    :param ref_full_params: reference number of parameters of the unpruned model
    :param ref_current_params: reference number of parameters of the pruned model
    """
    config = get_basic_pruning_config(8)
    compression_cfg = config['compression']
    compression_cfg['algorithm'] = 'filter_pruning'
    compression_cfg['pruning_init'] = pruning_flops_target
    algo_params = compression_cfg['params']
    algo_params['pruning_flops_target'] = pruning_flops_target
    algo_params['prune_first_conv'] = True
    algo_params['all_weights'] = all_weights

    input_shape = [1, 8, 8, 1]
    model = model_fn(input_shape)
    model.compile()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    # Statistics cached by the controller itself.
    assert compression_ctrl.full_flops == ref_full_flops
    assert compression_ctrl.full_params_num == ref_full_params
    assert compression_ctrl.current_flops == ref_current_flops
    assert compression_ctrl.current_params_num == ref_current_params

    # Statistics recomputed independently from the pruning masks.
    # pylint:disable=protected-access
    layer_groups = compression_ctrl._pruned_layer_groups_info.get_all_clusters()
    current_masks = compression_ctrl._collect_pruning_masks()
    masked_in_channels, masked_out_channels = calculate_in_out_channels_by_masks(
        pruning_groups=layer_groups,
        masks=current_masks,
        tensor_processor=TFNNCFCollectorTensorProcessor,
        full_input_channels=compression_ctrl._layers_in_channels,
        full_output_channels=compression_ctrl._layers_out_channels,
        pruning_groups_next_nodes=compression_ctrl._next_nodes)

    recomputed = count_flops_and_weights(
        compression_ctrl._original_graph,
        compression_ctrl._layers_in_shapes,
        compression_ctrl._layers_out_shapes,
        input_channels=masked_in_channels,
        output_channels=masked_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    cur_flops, cur_params_num = recomputed
    assert (cur_flops, cur_params_num) == (ref_current_flops, ref_current_params)
Beispiel #6
0
    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_level: float) -> None:
        """
        Zero out the globally least important filters until the FLOPs budget is met.

        All prunable filters are ranked by the cumulative importance of their cluster. Starting
        from the least important one, filters are masked out one at a time and the FLOPs of the
        resulting model are re-estimated after each step. Clusters whose pruning quota is used
        up are skipped. On success ``self.current_flops`` and ``self.current_params_num`` are
        updated with the values of the pruned model.

        :param target_flops_pruning_level: target proportion of flops removed from the model
        :raises RuntimeError: if the FLOPs budget is not reached after all candidates were tried
        """
        flops_budget = self.full_flops * (1 - target_flops_pruning_level)

        # Step 1: reset every mask to all-ones, i.e. nothing is pruned.
        for module_info in self.pruned_module_groups_info.get_all_nodes():
            ones_mask = torch.ones(get_filters_num(module_info.module))
            self.set_mask(module_info, ones_mask.to(module_info.module.weight.device))

        # Step 2: collect, per cluster, the importance of each filter together with the
        # cluster id and the filter position it belongs to.
        per_cluster_importances = []
        per_cluster_ids = []
        per_cluster_filter_ids = []

        for cluster in self.pruned_module_groups_info.get_all_clusters():
            filter_counts = torch.tensor(
                [get_filters_num(member.module) for member in cluster.elements])
            # All modules of one cluster are expected to have the same number of filters.
            assert torch.all(filter_counts == filter_counts[0])
            cluster_device = cluster.elements[0].module.weight.device

            group_importance = torch.zeros(filter_counts[0]).to(cluster_device)
            # Sum the (linearly rescaled) importances of all modules in the cluster.
            for member in cluster.elements:
                member_weight = member.module.weight
                if self.normalize_weights:
                    member_weight = self.weights_normalizer(member_weight)
                raw_importance = self.filter_importance(
                    member_weight, member.module.target_weight_dim_for_compression)
                coeffs = self.ranking_coeffs[member.node_name]
                group_importance += coeffs[0] * raw_importance + coeffs[1]

            per_cluster_importances.append(group_importance)
            per_cluster_ids.append(cluster.id * torch.ones_like(group_importance))
            per_cluster_filter_ids.append(torch.arange(len(group_importance)))

        all_importances = torch.cat(per_cluster_importances)
        all_cluster_ids = torch.cat(per_cluster_ids)
        all_filter_ids = torch.cat(per_cluster_filter_ids)

        # Step 3: walk over the filters in ascending order of importance and prune them
        # one by one until the FLOPs of the model drop below the budget.
        ranked_filters = sorted(zip(all_importances, all_cluster_ids, all_filter_ids),
                                key=lambda entry: entry[0])
        # Work on copies so that the bookkeeping of the full model stays intact.
        in_channels = self._modules_in_channels.copy()
        out_channels = self._modules_out_channels.copy()
        remaining_quotas = self.pruning_quotas.copy()

        for _, cluster_id_value, filter_id_value in ranked_filters:
            cluster_id = int(cluster_id_value)
            filter_id = int(filter_id_value)

            # Skip filters of clusters that may not be pruned any further.
            if remaining_quotas[cluster_id] <= 0:
                continue
            remaining_quotas[cluster_id] -= 1

            cluster = self.pruned_module_groups_info.get_cluster_by_id(cluster_id)
            for node in cluster.elements:
                out_channels[node.node_name] -= 1
                if node.is_depthwise:
                    in_channels[node.node_name] -= 1
                node.operand.binary_filter_pruning_mask[filter_id] = 0

            # The consumers of this cluster lose one input channel as well.
            for next_node_id in self.next_nodes[cluster.id]:
                in_channels[next_node_id] -= 1

            flops, params_num = count_flops_and_weights(
                self._model.get_original_graph(),
                self._modules_in_shapes,
                self._modules_out_shapes,
                input_channels=in_channels,
                output_channels=out_channels,
                conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                linear_op_metatypes=LINEAR_LAYER_METATYPES)
            if flops < flops_budget:
                self.current_flops = flops
                self.current_params_num = params_num
                return

        raise RuntimeError("Can't prune model to asked flops pruning level")
Beispiel #7
0
    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_level: float):
        """
        Prunes least important filters one-by-one until target FLOPs pruning level is achieved.
        Filters are sorted by filter importance score.

        Masks are rebuilt from all-ones on every call. The per-group pruning quotas are
        therefore consumed on a local copy, so that ``self._pruning_quotas`` is not depleted
        by a call and every call starts from the same quotas.

        On success ``self.current_flops`` and ``self.current_params_num`` are updated and the
        resulting masks are applied to the wrapped layers.

        :param target_flops_pruning_level: target proportion of flops removed from the model
        :raises RuntimeError: if the target FLOPs are not reached after all candidate filters
            allowed by the quotas were pruned
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        target_flops = self.full_flops * (1 - target_flops_pruning_level)
        wrapped_layers = collect_wrapped_layers(self._model)
        masks = {}

        # Reset the output mask of every wrapped layer's graph node to all-ones.
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            nncf_node.data['output_mask'] = TFNNCFTensor(
                tf.ones(get_filters_num(layer)))

        # 1. Calculate importances for all groups of filters. Initialize masks.
        filter_importances = []
        group_indexes = []
        filter_indexes = []
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances.extend(cumulative_filters_importance)
            filters_num = len(cumulative_filters_importance)
            group_indexes.extend([group.id] * filters_num)
            filter_indexes.extend(range(filters_num))
            masks[group.id] = tf.ones(filters_num)

        # 2. Prune filters in ascending order of importance, re-estimating FLOPs after each one.
        tmp_in_channels = self._layers_in_channels.copy()
        tmp_out_channels = self._layers_out_channels.copy()
        # Track quotas on a copy: decrementing the attribute itself would leave less (or no)
        # quota for subsequent calls although the masks start from all-ones again.
        tmp_pruning_quotas = self._pruning_quotas.copy()
        sorted_importances = sorted(zip(filter_importances, group_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        for _, group_id, filter_index in sorted_importances:
            if tmp_pruning_quotas[group_id] == 0:
                continue
            masks[group_id] = tf.tensor_scatter_nd_update(
                masks[group_id], [[filter_index]], [0])
            tmp_pruning_quotas[group_id] -= 1

            # Update input/output shapes of pruned elements
            group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
            for node in group.elements:
                tmp_out_channels[node.node_name] -= 1
                if node.is_depthwise:
                    tmp_in_channels[node.node_name] -= 1

            # The consumers of this group lose one input channel as well.
            for node_name in self._next_nodes[group_id]:
                tmp_in_channels[node_name] -= 1

            flops, params_num = count_flops_and_weights(
                self._original_graph,
                self._layers_in_shapes,
                self._layers_out_shapes,
                input_channels=tmp_in_channels,
                output_channels=tmp_out_channels,
                conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                linear_op_metatypes=LINEAR_LAYER_METATYPES)
            if flops <= target_flops:
                # 3. Add masks to the graph and propagate them
                for group in self._pruned_layer_groups_info.get_all_clusters():
                    for node in group.elements:
                        nncf_node = self._original_graph.get_node_by_id(
                            node.nncf_node_id)
                        nncf_node.data['output_mask'] = TFNNCFTensor(
                            masks[group.id])

                mask_propagator = MaskPropagationAlgorithm(
                    self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
                    TFNNCFPruningTensorProcessor)
                mask_propagator.mask_propagation()

                # 4. Set binary masks to the model
                self.current_flops = flops
                self.current_params_num = params_num
                nncf_sorted_nodes = self._original_graph.topological_sort()
                for layer in wrapped_layers:
                    nncf_node = [
                        n for n in nncf_sorted_nodes
                        if layer.name == n.layer_name
                    ][0]
                    # Nodes without an output mask are left untouched.
                    if nncf_node.data['output_mask'] is not None:
                        self._set_operation_masks(
                            [layer], nncf_node.data['output_mask'].tensor)
                return
        raise RuntimeError(
            f'Unable to prune model to required flops pruning level:'
            f' {target_flops_pruning_level}')