Example 1
    def _collect_all_weights(self):
        all_weights = []
        for wrapped_layer in collect_wrapped_layers(self._model):
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                for op_name in ops:
                    if op_name in self._op_names:
                        all_weights.append(tf.reshape(
                            self._weight_importance_fn(wrapped_layer.layer_weights[weight_attr]),
                            [-1]))
        return all_weights
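The `_weight_importance_fn` used above is pluggable. A minimal sketch of two magnitude-style importance functions (the normalized variant is an assumption, not necessarily what NNCF ships):

import tensorflow as tf

def weight_importance_abs(weight: tf.Tensor) -> tf.Tensor:
    # Plain magnitude importance: larger absolute value == more important.
    return tf.abs(weight)

def weight_importance_normed_abs(weight: tf.Tensor) -> tf.Tensor:
    # Magnitude normalized by the layer's norm, so importances collected
    # from different layers are comparable under a single global threshold.
    return tf.abs(weight) / tf.norm(weight)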
Example 2
    def _set_masks_for_threshold(self, threshold_val):
        for wrapped_layer in collect_wrapped_layers(self._model):
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                weight = wrapped_layer.layer_weights[weight_attr]

                for op_name, op in ops.items():
                    if isinstance(op, BinaryMask):
                        wrapped_layer.ops_weights[op_name].assign(
                            calc_magnitude_binary_mask(weight,
                                                       self.weight_importance,
                                                       threshold_val))
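`calc_magnitude_binary_mask` is not shown here; a sketch consistent with how it is called above (the strict `>` comparison is an assumption):

import tensorflow as tf

def calc_magnitude_binary_mask(weight, weight_importance, threshold_val):
    # Keep (1.0) every weight whose importance exceeds the threshold,
    # zero out (0.0) the rest.
    return tf.cast(weight_importance(weight) > threshold_val, tf.float32)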
Example 3
    def _set_binary_masks_for_pruned_layers_globally(self,
                                                     pruning_rate: float):
        """
        Sets the binary mask values for layer groups according to the global pruning rate.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_rate proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.
        """
        nncf_logger.debug(
            'Setting new binary masks for all pruned modules together.')
        filter_importances = {}
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the nodes of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        # a. Calculate importances for all groups of filters
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = \
                self._calculate_filters_importance_in_group(group, wrapped_layers)
            filter_importances[group.id] = cumulative_filters_importance

        # b. Calculate one threshold for all filters
        importances = tf.concat(list(filter_importances.values()), 0)
        # Clamp the index so pruning_rate == 1.0 does not overflow the list
        threshold = sorted(importances)[min(
            int(pruning_rate * importances.shape[0]),
            importances.shape[0] - 1)]

        # c. Initialize masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            filter_mask = calculate_binary_mask(filter_importances[group.id],
                                                threshold)
            for node in group.nodes:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = filter_mask

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes
                if layer.layer.name == get_original_name(n.node_name)
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'])
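`calculate_binary_mask` is called with per-filter importances and a scalar threshold; a sketch consistent with that call site (whether ties at the threshold are kept or pruned is an assumption):

import tensorflow as tf

def calculate_binary_mask(importance: tf.Tensor, threshold) -> tf.Tensor:
    # One mask element per filter: 1.0 keeps the filter, 0.0 prunes it.
    return tf.cast(importance >= threshold, tf.float32)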
Example 4
    def _collect_all_weights(self):
        all_weights = []
        for wrapped_layer in collect_wrapped_layers(self._model):
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                for op in ops.values():
                    if isinstance(op, BinaryMask):
                        all_weights.append(
                            tf.reshape(
                                self.weight_importance(
                                    wrapped_layer.layer_weights[weight_attr]),
                                [-1]))
        return all_weights
Example 5
    def _set_masks_for_threshold(self, threshold_val):
        for wrapped_layer in collect_wrapped_layers(self._model):
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                weight = wrapped_layer.layer_weights[weight_attr]

                for op_name in ops:
                    if op_name in self._op_names:
                        wrapped_layer.ops_weights[op_name]['mask'].assign(
                            calc_magnitude_binary_mask(weight,
                                                       self._weight_importance_fn,
                                                       threshold_val)
                        )
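Examples 1 and 5 fit together as one magnitude-sparsity step: collect the flattened importances, pick a threshold for the target level, then apply it. A sketch of the glue (`select_threshold` and the `controller` wiring are illustrative, not NNCF's API):

import tensorflow as tf

def select_threshold(all_weights, sparsity_level: float) -> tf.Tensor:
    # Sort every flattened importance value and take the one separating
    # the `sparsity_level` fraction of least-important weights.
    importances = tf.sort(tf.concat(all_weights, axis=0))
    index = min(int(sparsity_level * int(importances.shape[0])),
                int(importances.shape[0]) - 1)
    return importances[index]

def apply_magnitude_sparsity(controller, sparsity_level: float) -> None:
    # `controller` is assumed to expose the two methods shown in
    # Examples 1 and 5.
    threshold = select_threshold(controller._collect_all_weights(),
                                 sparsity_level)
    controller._set_masks_for_threshold(threshold)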
Example 6
def apply_fn_to_op_weights(model: tf.keras.Model,
                           op_names: List[str],
                           fn=lambda x: x):
    sparsified_layers = collect_wrapped_layers(model)
    target_ops = []
    for layer in sparsified_layers:
        for ops in layer.weights_attr_ops.values():
            for op in ops.values():
                if op.name in op_names:
                    weight = layer.get_operation_weights(op.name)
                    target_ops.append((op, fn(weight)))
    return target_ops
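A possible use of `apply_fn_to_op_weights`, pairing each sparsifying op with its mask variable (`model` and `op_names` come from the caller; the `'mask'` key matches the other examples here but is an assumption about the weights dict layout):

# Pair each matching op with its mask variable instead of the full
# weights dict.
op_mask_pairs = apply_fn_to_op_weights(model, op_names,
                                       fn=lambda weights: weights['mask'])
for op, mask in op_mask_pairs:
    print(op.name, mask.shape)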
Example 7
    def raw_statistics(self):
        raw_sparsity_statistics = {}
        sparsity_levels = []
        mask_names = []
        weights_shapes = []
        weights_numbers = []
        total_weights_number = tf.constant(0)
        total_sparsified_weights_number = tf.constant(0)
        total_bkup_weights_number = tf.constant(0)
        wrapped_layers = collect_wrapped_layers(self._model)
        for wrapped_layer in wrapped_layers:
            for ops in wrapped_layer.weights_attr_ops.values():
                for op_name, op in ops.items():
                    if op_name in self._op_names:
                        if isinstance(op, BinaryMaskWithWeightsBackup):
                            total_bkup_weights_number += tf.size(op.bkup_var)
                        if isinstance(op, BinaryMask):
                            mask = wrapped_layer.ops_weights[op_name]['mask']
                            mask_names.append(mask.name)
                            weights_shapes.append(list(mask.shape))
                            weights_number = tf.size(mask)
                            weights_numbers.append(weights_number)
                            sparsified_weights_number = weights_number - tf.reduce_sum(tf.cast(mask, tf.int32))
                            sparsity_levels.append(sparsified_weights_number / weights_number)
                            total_weights_number += weights_number
                            total_sparsified_weights_number += sparsified_weights_number

        sparsity_rate_for_sparsified_modules = (total_sparsified_weights_number / total_weights_number).numpy()
        model_weights_number = count_params(self._model.weights) - total_weights_number - total_bkup_weights_number
        sparsity_rate_for_model = (total_sparsified_weights_number / model_weights_number).numpy()

        raw_sparsity_statistics.update({
            'sparsity_rate_for_sparsified_modules': sparsity_rate_for_sparsified_modules,
            'sparsity_rate_for_model': sparsity_rate_for_model,
            'sparsity_threshold': self._threshold
        })

        sparsity_levels = tf.keras.backend.batch_get_value(sparsity_levels)
        weights_percentages = [weights_number / total_weights_number * 100
                               for weights_number in weights_numbers]
        weights_percentages = tf.keras.backend.batch_get_value(weights_percentages)
        mask_sparsity = list(zip(mask_names, weights_shapes, sparsity_levels, weights_percentages))
        raw_sparsity_statistics['sparsity_statistic_by_layer'] = []
        for mask_name, weights_shape, sparsity_level, weights_percentage in mask_sparsity:
            raw_sparsity_statistics['sparsity_statistic_by_layer'].append({
                'Name': mask_name,
                'Weight\'s Shape': weights_shape,
                'SR': sparsity_level,
                '% weights': weights_percentage
            })

        return raw_sparsity_statistics
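A small worked example of the per-mask arithmetic above, runnable on its own: a 4x4 mask with six zeros gives a sparsity level of 6/16 = 0.375.

import tensorflow as tf

mask = tf.constant([[1., 0., 1., 0.],
                    [1., 1., 0., 1.],
                    [0., 1., 1., 1.],
                    [1., 0., 1., 0.]])
weights_number = tf.size(mask)                                        # 16
sparsified = weights_number - tf.reduce_sum(tf.cast(mask, tf.int32))  # 6
print(float(sparsified / weights_number))                             # 0.375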
Example 8
    def raw_statistics(self) -> Dict[str, object]:
        raw_pruning_statistics = {}
        pruning_rates = []
        mask_names = []
        weights_shapes = []
        mask_shapes = []
        wrapped_layers = collect_wrapped_layers(self._model)
        for wrapped_layer in wrapped_layers:
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                for op_name in ops:
                    if op_name in self._op_names:
                        mask = wrapped_layer.ops_weights[op_name]['mask']
                        mask_names.append(mask.name)
                        weights_shapes.append(list(mask.shape))
                        reduce_axes = list(range(len(mask.shape)))
                        filter_axis = get_filter_axis(wrapped_layer,
                                                      weight_attr)
                        if filter_axis == -1:
                            filter_axis = reduce_axes[filter_axis]
                        reduce_axes.remove(filter_axis)
                        filter_mask = tf.reduce_max(tf.cast(mask, tf.int32),
                                                    axis=reduce_axes,
                                                    keepdims=True)
                        mask_shapes.append(list(filter_mask.shape))
                        filters_number = get_filters_num(wrapped_layer)
                        pruned_filters_number = filters_number - tf.reduce_sum(
                            filter_mask)
                        pruning_rates.append(pruned_filters_number /
                                             filters_number)

        raw_pruning_statistics.update({'pruning_rate': self.pruning_rate})

        pruning_rates = tf.keras.backend.batch_get_value(pruning_rates)

        mask_pruning = list(
            zip(mask_names, weights_shapes, mask_shapes, pruning_rates))
        raw_pruning_statistics['pruning_statistic_by_layer'] = []
        for mask_name, weights_shape, mask_shape, pruning_rate in mask_pruning:
            raw_pruning_statistics['pruning_statistic_by_layer'].append({
                'Name': mask_name,
                'Weight\'s Shape': weights_shape,
                'Mask Shape': mask_shape,
                'PR': pruning_rate
            })

        return raw_pruning_statistics
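To illustrate the filter-mask reduction above: for a Conv2D kernel mask of shape [kh, kw, c_in, c_out] with the filter axis last, taking the max over the remaining axes leaves one 0/1 value per output filter.

import tensorflow as tf

mask = tf.ones([3, 3, 8, 16])        # kernel mask, filter axis last
reduce_axes = [0, 1, 2]              # every axis except the filter axis
filter_mask = tf.reduce_max(tf.cast(mask, tf.int32),
                            axis=reduce_axes, keepdims=True)
print(filter_mask.shape)             # (1, 1, 1, 16)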
Example 9
    def _set_binary_masks_for_pruned_layers_groupwise(self,
                                                      pruning_rate: float):
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the nodes of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            # a. Calculate the cumulative importance for all filters in the group
            cumulative_filters_importance = \
                self._calculate_filters_importance_in_group(group, wrapped_layers)
            filters_num = len(cumulative_filters_importance)

            # b. Calculate threshold
            num_of_sparse_elems = get_rounded_pruned_element_number(
                cumulative_filters_importance.shape[0], pruning_rate)
            threshold = sorted(cumulative_filters_importance)[min(
                num_of_sparse_elems, filters_num - 1)]

            # c. Initialize masks
            filter_mask = calculate_binary_mask(cumulative_filters_importance,
                                                threshold)
            for node in group.nodes:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = filter_mask

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes
                if layer.layer.name == get_original_name(n.node_name)
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'])
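`get_rounded_pruned_element_number` is used above to turn the group's pruning rate into an element count; a plausible sketch, assuming it rounds so that the number of remaining filters is a multiple of some alignment value (the exact rounding rule is an assumption):

import math

def get_rounded_pruned_element_number(total: int, pruning_rate: float,
                                      multiple_of: int = 8) -> int:
    # Round so the number of *remaining* filters stays a multiple of
    # `multiple_of`, which tends to produce hardware-friendlier shapes.
    remaining = math.ceil(total * (1 - pruning_rate) / multiple_of) * multiple_of
    return max(total - remaining, 0)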
Example 10
    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_rate: float):
        """
        Prunes least important filters one-by-one until target FLOPs pruning rate is achieved.
        Filters are sorted by filter importance score.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        target_flops = self.full_flops * (1 - target_flops_pruning_rate)
        wrapped_layers = collect_wrapped_layers(self._model)
        masks = {}  # per-group filter masks, keyed by group id

        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes
                if layer.layer.name == get_original_name(n.node_name)
            ][0]
            nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer))

        # 1. Calculate importances for all groups of filters. Initialize masks.
        filter_importances = []
        group_indexes = []
        filter_indexes = []
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = \
                self._calculate_filters_importance_in_group(group, wrapped_layers)

            filter_importances.extend(cumulative_filters_importance)
            filters_num = len(cumulative_filters_importance)
            group_indexes.extend([group.id] * filters_num)
            filter_indexes.extend(range(filters_num))
            # tf.Tensor is immutable, so keep the mask in a tf.Variable to
            # allow the per-filter zeroing below
            masks[group.id] = tf.Variable(tf.ones(filters_num))

        # 2. Prune the least important filters one-by-one until the target
        # FLOPs budget is met
        tmp_in_channels = self._layers_in_channels.copy()
        tmp_out_channels = self._layers_out_channels.copy()
        sorted_importances = sorted(zip(filter_importances, group_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        for _, group_id, filter_index in sorted_importances:
            if self._pruning_quotas[group_id] == 0:
                continue
            masks[group_id][filter_index].assign(0)
            self._pruning_quotas[group_id] -= 1

            # Update input/output shapes of pruned nodes
            group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
            for node in group.nodes:
                tmp_out_channels[node.layer_name] -= 1
            for node_name in self._next_nodes[group_id]:
                tmp_in_channels[node_name] -= 1

            flops = sum(
                count_flops_for_nodes(self._original_graph,
                                      self._layers_in_shapes,
                                      self._layers_out_shapes,
                                      input_channels=tmp_in_channels,
                                      output_channels=tmp_out_channels,
                                      conv_op_types=GENERAL_CONV_LAYERS,
                                      linear_op_types=LINEAR_LAYERS).values())
            if flops <= target_flops:
                # 3. Add masks to the graph and propagate them
                for group in self._pruned_layer_groups_info.get_all_clusters():
                    for node in group.nodes:
                        nncf_node = self._original_graph.get_node_by_id(
                            node.nncf_node_id)
                        nncf_node.data['output_mask'] = masks[group.id]

                mask_propagator = MaskPropagationAlgorithm(
                    self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
                mask_propagator.mask_propagation()

                # 4. Set binary masks to the model
                self.current_flops = flops
                nncf_sorted_nodes = self._original_graph.topological_sort()
                for layer in wrapped_layers:
                    nncf_node = [
                        n for n in nncf_sorted_nodes
                        if layer.layer.name == get_original_name(n.node_name)
                    ][0]
                    if nncf_node.data['output_mask'] is not None:
                        self._set_operation_masks(
                            [layer], nncf_node.data['output_mask'])
                return
        raise RuntimeError(
            'Unable to prune the model to the required FLOPs pruning rate: '
            f'{target_flops_pruning_rate}')
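For orientation, a rough sketch of the per-node estimate that `count_flops_for_nodes` could produce for a 2D convolution (the formula and names are illustrative, not the library's actual implementation):

def conv2d_flops(out_h: int, out_w: int, kernel_h: int, kernel_w: int,
                 in_channels: int, out_channels: int) -> int:
    # Two ops (multiply + add) per kernel element, input channel and
    # output position, for every output filter; pruning a filter shrinks
    # out_channels here and in_channels of the consumers.
    return 2 * out_h * out_w * kernel_h * kernel_w * in_channels * out_channels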
Example 11
    def raw_statistics(self):
        raw_sparsity_statistics = {}
        sparsity_levels = []
        mask_names = []
        weights_shapes = []
        weights_numbers = []
        sparse_prob_sum = tf.constant(0.)
        total_weights_number = tf.constant(0)
        total_sparsified_weights_number = tf.constant(0)
        wrapped_layers = collect_wrapped_layers(self._model)
        for wrapped_layer in wrapped_layers:
            for ops in wrapped_layer.weights_attr_ops.values():
                for op_name in ops:
                    if op_name in self._op_names:
                        mask = wrapped_layer.ops_weights[op_name]['mask']
                        sw_loss = tf.reduce_sum(binary_mask(mask))
                        weights_number = tf.size(mask)
                        sparsified_weights_number = weights_number - tf.cast(
                            sw_loss, tf.int32)
                        mask_names.append(wrapped_layer.name + '_rb_mask')
                        weights_shapes.append(list(mask.shape))
                        weights_numbers.append(weights_number)
                        sparsity_levels.append(sparsified_weights_number /
                                               weights_number)
                        sparse_prob_sum += tf.math.reduce_sum(
                            tf.math.sigmoid(mask))
                        total_weights_number += weights_number
                        total_sparsified_weights_number += sparsified_weights_number

        sparsity_rate_for_sparsified_modules = (
            total_sparsified_weights_number / total_weights_number).numpy()
        model_weights_number = count_params(
            self._model.weights) - total_weights_number
        sparsity_rate_for_model = (total_sparsified_weights_number /
                                   model_weights_number).numpy()
        mean_sparse_prob = (sparse_prob_sum /
                            tf.cast(total_weights_number, tf.float32)).numpy()

        raw_sparsity_statistics.update({
            'sparsity_rate_for_sparsified_modules': sparsity_rate_for_sparsified_modules,
            'sparsity_rate_for_model': sparsity_rate_for_model,
            'mean_sparse_prob': mean_sparse_prob,
            'target_sparsity_rate': self.loss.target_sparsity_rate,
        })

        sparsity_levels = tf.keras.backend.batch_get_value(sparsity_levels)
        weights_percentages = [
            weights_number / total_weights_number * 100
            for weights_number in weights_numbers
        ]
        weights_percentages = tf.keras.backend.batch_get_value(
            weights_percentages)
        mask_sparsity = list(
            zip(mask_names, weights_shapes, sparsity_levels,
                weights_percentages))
        raw_sparsity_statistics['sparsity_statistic_by_layer'] = []
        for mask_name, weights_shape, sparsity_level, weights_percentage in mask_sparsity:
            raw_sparsity_statistics['sparsity_statistic_by_layer'].append({
                'Name': mask_name,
                'Weight\'s Shape': weights_shape,
                'SR': sparsity_level,
                '% weights': weights_percentage
            })

        return raw_sparsity_statistics
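The `binary_mask` helper above converts the trainable RB mask into hard 0/1 values; a sketch consistent with the sigmoid usage in the statistics (the mask is treated as logits, and the zero cutoff is an assumption):

import tensorflow as tf

def binary_mask(mask: tf.Tensor) -> tf.Tensor:
    # A positive logit means keep probability sigmoid(mask) > 0.5,
    # so the weight is kept (1.0); otherwise it is zeroed (0.0).
    return tf.cast(mask > 0, tf.float32)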