Example #1
    def _is_module_prunable(self, graph: NNCFGraph,
                            node: NNCFNode) -> PruningAnalysisDecision:
        """
        Check whether we should prune module corresponding to provided node
        according to algorithm parameters.

        :param graph: Graph to work with.
        :param node: Node to check.
        :return: Pruning analysis decision.
        """
        stop_propagation_ops = self._stop_propagation_op_metatype.get_all_op_aliases()
        types_to_track = self._prune_operations_types + stop_propagation_ops
        input_non_pruned_nodes = get_first_nodes_of_type(graph, types_to_track)
        node_name = node.node_name

        if not should_consider_scope(node_name, self._ignored_scopes,
                                     self._target_scopes):
            return PruningAnalysisDecision(False,
                                           PruningAnalysisReason.IGNORED_SCOPE)

        if not self._prune_first and node in input_non_pruned_nodes:
            return PruningAnalysisDecision(False,
                                           PruningAnalysisReason.FIRST_CONV)

        if is_grouped_conv(node) and not is_prunable_depthwise_conv(node):
            return PruningAnalysisDecision(False,
                                           PruningAnalysisReason.GROUP_CONV)

        if not self._prune_downsample_convs and is_conv_with_downsampling(
                node):
            return PruningAnalysisDecision(
                False, PruningAnalysisReason.DOWNSAMPLE_CONV)

        return PruningAnalysisDecision(True)
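
For context, PruningAnalysisDecision acts as a boolean-like result object: truthy when pruning is allowed, otherwise carrying a reason. A minimal sketch of that pattern, with hypothetical stand-in classes rather than NNCF's actual definitions:

from enum import Enum
from typing import Optional


class Reason(Enum):
    # Hypothetical stand-in for PruningAnalysisReason
    IGNORED_SCOPE = 'node in ignored scope'
    FIRST_CONV = 'first convolution'
    GROUP_CONV = 'grouped convolution'
    DOWNSAMPLE_CONV = 'downsampling convolution'


class Decision:
    # Hypothetical stand-in for PruningAnalysisDecision: truthy when pruning
    # is allowed, and carries a reason when it is not
    def __init__(self, decision: bool, reason: Optional[Reason] = None) -> None:
        self.decision = decision
        self.reason = reason

    def __bool__(self) -> bool:
        return self.decision


decision = Decision(False, Reason.GROUP_CONV)
if not decision:
    print(f'Pruning skipped: {decision.reason.value}')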
Example #2
    @classmethod
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(input_mask))

        is_depthwise = is_prunable_depthwise_conv(node)
        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(1))

        if is_depthwise:
            # In the depthwise case the output channels are pruned via the input
            # mask; here we only update the module attributes for the new number
            # of input channels
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
            old_num_channels = int(node_module.weight.size(0))
        else:
            out_channels = node_module.weight.size(0)
            # Broadcast the 1D input mask over the output-channel dimension so
            # it can index the 4D weight tensor directly
            broadcasted_mask = bool_mask.repeat(out_channels).view(
                out_channels, bool_mask.size(0))
            new_weight_shape = list(node_module.weight.shape)
            new_weight_shape[1] = new_num_channels

            node_module.in_channels = new_num_channels
            node_module.weight = torch.nn.Parameter(
                node_module.weight[broadcasted_mask].view(new_weight_shape))

        nncf_logger.info(
            'Pruned convolution {} by input mask. Input channels before: {}, '
            'after: {}.'.format(node.data['key'], old_num_channels,
                                new_num_channels))
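
The non-depthwise branch drops input channels via boolean-mask indexing of the weight tensor. A self-contained sketch of the same trick on a plain torch.nn.Conv2d, without any NNCF helpers (the mask values are made up for illustration):

import torch

conv = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
input_mask = torch.tensor([1, 0, 1, 1])  # keep input channels 0, 2 and 3
bool_mask = input_mask.to(torch.bool)
new_num_channels = int(input_mask.sum())

out_channels = conv.weight.size(0)
# Broadcast the 1D mask over the output-channel dimension, select the
# surviving weight slices, then restore the 4D weight shape
broadcasted_mask = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
new_weight_shape = list(conv.weight.shape)
new_weight_shape[1] = new_num_channels

conv.in_channels = new_num_channels
with torch.no_grad():
    conv.weight = torch.nn.Parameter(conv.weight[broadcasted_mask].view(new_weight_shape))

out = conv(torch.randn(1, new_num_channels, 16, 16))
assert out.shape == (1, 8, 14, 14)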
Example #3
    @classmethod
    def mask_propagation(
            cls, node: NNCFNode, graph: NNCFGraph,
            tensor_processor: Type[NNCFPruningBaseTensorProcessor]) -> None:
        input_masks = get_input_masks(node, graph)
        output_mask = node.data.get('output_mask', None)

        if is_grouped_conv(node):
            # Masks cannot be propagated through a generic grouped convolution;
            # only a depthwise convolution passes its input mask through unchanged
            output_mask = None
            if is_prunable_depthwise_conv(node):
                output_mask = input_masks[0]

        node.data['output_mask'] = output_mask
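
In short: generic grouped convolutions block mask propagation, while depthwise convolutions pass the input mask through unchanged, since each output channel depends on exactly one input channel. A toy sketch of that rule, using a hypothetical helper rather than NNCF's API:

def propagate_conv_mask(input_mask, groups, in_channels, out_channels):
    """Toy analogue of the propagation rule above (not NNCF's API)."""
    is_grouped = groups > 1
    is_depthwise = groups == in_channels == out_channels
    if is_grouped and not is_depthwise:
        return None        # generic grouped conv: mask cannot be propagated
    if is_depthwise:
        return input_mask  # depthwise conv: input mask passes through
    return None            # ordinary conv: output mask comes from its own pruner

# A 3->3 depthwise conv (groups=3) forwards the mask; groups=2 over 4 channels does not
assert propagate_conv_mask([1, 0, 1], groups=3, in_channels=3, out_channels=3) == [1, 0, 1]
assert propagate_conv_mask([1, 0, 1, 1], groups=2, in_channels=4, out_channels=4) is None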
Example #4
    def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Computes necessary model transformations (pruning mask insertions) to enable pruning.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        converter = TFModelConverterFactory.create(model)
        self._graph = converter.convert()
        groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

        transformations = TFTransformationLayout()
        shared_layers = set()

        self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)
                group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                    is_prunable_depthwise_conv(node)))

                # Set output_mask on the elements so that mask_propagation can
                # detect the spec_nodes that will be pruned.
                # This must be done for every element of a shared layer.
                node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
                if layer_name in shared_layers:
                    continue
                if node.is_shared():
                    shared_layers.add(layer_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

                _, layer_info = converter.get_layer_info_for_node(node.node_name)
                for weight_def in node.metatype.weight_definitions:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, weight_def.weight_attr_name)
                    )
                if node.metatype.bias_attr_name is not None and \
                        getattr(layer, node.metatype.bias_attr_name) is not None:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, node.metatype.bias_attr_name)
                    )

            cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
            self._pruned_layer_groups_info.add_cluster(cluster)

        # Propagating masks across the graph to detect spec_nodes that will be pruned
        mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                                   TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # Add masks for all special modules: prunable batch norm layers can only
        # be identified during mask propagation
        types_spec_layers = [TFBatchNormalizationLayerMetatype] \
            if self._prune_batch_norms else []

        spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
        for spec_node in spec_nodes:
            layer_name = get_layer_identifier(spec_node)
            layer = model.get_layer(layer_name)
            if spec_node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue
            if layer_name in shared_layers:
                continue
            if spec_node.is_shared():
                shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
            for weight_def in spec_node.metatype.weight_definitions:
                if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                        and not layer.scale and weight_def.weight_attr_name == 'gamma':
                    nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                      'Do not add mask to it.')
                    continue

                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, spec_node.metatype.bias_attr_name)
            )
        return transformations
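
The gamma check above matters because Keras creates the gamma weight of BatchNormalization only when scale=True; with scale=False there is simply no variable to attach a mask to. A quick standalone check (assuming TF2 Keras):

import tensorflow as tf

bn_scaled = tf.keras.layers.BatchNormalization(scale=True)
bn_no_scale = tf.keras.layers.BatchNormalization(scale=False)
bn_scaled.build((None, 8))
bn_no_scale.build((None, 8))

print([w.name for w in bn_scaled.weights])    # gamma, beta, moving_mean, moving_variance
print([w.name for w in bn_no_scale.weights])  # no gamma weight is created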
Example #5
    def create_pruning_groups(self,
                              graph: NNCFGraph) -> Clusterization[NNCFNode]:
        """
        This function groups ALL nodes with pruning types to groups that should be pruned together.
            1. Create clusters for special ops (eltwises) that should be pruned together
            2. Create groups of nodes that should be pruned together (taking into account clusters of special ops)
            3. Add remaining single nodes
            4. Unite clusters for Conv + Depthwise conv (should be pruned together too)
            5. Checks for groups (all nodes in group can prune or all group can't be pruned)
        Return groups of modules that should be pruned together.

        :param graph: Graph to work with and their initialization parameters as values.
        :return: Clusterization of pruned nodes.
        """
        # pylint:disable=too-many-branches
        all_nodes_to_prune = graph.get_nodes_by_types(
            self._prune_operations_types)  # NNCFNodes here

        # 1. Clusters for special ops
        identity_like_types = self._identity_mask_propagation_op_metatype.get_all_op_aliases()
        special_ops_clusterization = cluster_special_ops(
            graph, self._grouping_operations_types, identity_like_types)

        pruned_nodes_clusterization = Clusterization[NNCFNode](
            lambda x: x.node_id)

        # 2. Clusters for nodes that should be pruned together (taking into account clusters for special ops)
        for i, cluster in enumerate(
                special_ops_clusterization.get_all_clusters()):
            all_pruned_inputs = {}
            clusters_to_merge = []

            for node in cluster.elements:
                sources = get_sources_of_node(node, graph,
                                              self._prune_operations_types)
                for source_node in sources:
                    if pruned_nodes_clusterization.is_node_in_clusterization(
                            source_node.node_id):
                        # Merge clusters if this source node was already added to another cluster
                        cluster = pruned_nodes_clusterization.get_cluster_containing_element(
                            source_node.node_id)
                        clusters_to_merge.append(cluster.id)
                    elif source_node.node_id not in all_pruned_inputs:
                        all_pruned_inputs[source_node.node_id] = source_node

            if all_pruned_inputs:
                cluster = Cluster[NNCFNode](i, all_pruned_inputs.values(),
                                            all_pruned_inputs.keys())
                clusters_to_merge.append(cluster.id)
                pruned_nodes_clusterization.add_cluster(cluster)

            # Merge clusters when one source node feeds several special-ops clusters
            pruned_nodes_clusterization.merge_list_of_clusters(
                clusters_to_merge)

        last_cluster_idx = len(special_ops_clusterization.get_all_clusters())

        # 3. Add remaining single nodes as separate clusters
        for node in all_nodes_to_prune:
            if not pruned_nodes_clusterization.is_node_in_clusterization(
                    node.node_id):
                cluster = Cluster[NNCFNode](last_cluster_idx, [node],
                                            [node.node_id])
                pruned_nodes_clusterization.add_cluster(cluster)

                last_cluster_idx += 1

        stop_propagation_ops = self._stop_propagation_op_metatype.get_all_op_aliases()
        # 4. Merge clusters for Conv + Depthwise conv (should be pruned together too)
        for node in all_nodes_to_prune:
            cluster_id = pruned_nodes_clusterization.get_cluster_containing_element(
                node.node_id).id

            if is_prunable_depthwise_conv(node):
                previous_convs = get_previous_convs(
                    graph, node, self._prune_operations_types,
                    stop_propagation_ops)
                previous_clusters = [
                    pruned_nodes_clusterization.get_cluster_containing_element(
                        node.node_id).id for node in previous_convs
                ]
                pruned_nodes_clusterization.merge_list_of_clusters(
                    [cluster_id] + previous_clusters)

        # 5. Merge nodes into one cluster if some module forwards several times
        multiforward_nodes = self._get_multiforward_nodes(graph)
        for list_of_nodes in multiforward_nodes:
            clusters_to_merge = [
                pruned_nodes_clusterization.get_cluster_containing_element(
                    node.node_id).id for node in list_of_nodes
            ]
            pruned_nodes_clusterization.merge_list_of_clusters(
                clusters_to_merge)

            # Merge previous convolutions into one cluster
            all_previous_convs = []
            for node in list_of_nodes:
                nncf_node = graph.get_node_by_id(node.node_id)
                previous_convs = get_previous_convs(
                    graph, nncf_node, self._prune_operations_types,
                    stop_propagation_ops)
                # Skip previous convs that belong to this multiforward group
                # themselves, to avoid cycles between multiforward nodes
                for previous_conv in previous_convs:
                    if previous_conv not in list_of_nodes:
                        all_previous_convs.append(previous_conv)

            previous_clusters = [
                pruned_nodes_clusterization.get_cluster_containing_element(
                    node.node_id).id for node in all_previous_convs
            ]
            pruned_nodes_clusterization.merge_list_of_clusters(
                previous_clusters)

        # 6. Check the groups (either all nodes in a group can be pruned, or the whole group can't be pruned)
        model_analyser = ModelAnalyzer(graph, self._pruning_operator_metatypes,
                                       is_prunable_depthwise_conv)
        can_prune_analysis = model_analyser.analyse_model_before_pruning()
        can_prune_and_should_prune_analysis = self._should_prune_groups_analysis(
            graph, pruned_nodes_clusterization, can_prune_analysis)
        can_prune_final_analysis = self._pruning_dimensions_analysis(
            graph, can_prune_and_should_prune_analysis)
        self._filter_groups(pruned_nodes_clusterization,
                            can_prune_final_analysis)
        return pruned_nodes_clusterization
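
Most of the steps above reduce to repeated cluster merges, so a Clusterization-style container must support union-style merging by cluster id. A minimal union-find sketch of that mechanic, illustrative only and not NNCF's implementation:

class TinyClusters:
    """Toy analogue of the merge behaviour in Clusterization (not NNCF's API)."""

    def __init__(self):
        self._parent = {}   # cluster id -> representative cluster id
        self._members = {}  # representative id -> set of node ids

    def add_cluster(self, cid, node_ids):
        self._parent[cid] = cid
        self._members[cid] = set(node_ids)

    def _find(self, cid):
        # Follow parent links to the cluster's current representative
        while self._parent[cid] != cid:
            cid = self._parent[cid]
        return cid

    def merge(self, cids):
        roots = {self._find(c) for c in cids}
        if not roots:
            return
        target = roots.pop()
        for root in roots:
            self._parent[root] = target
            self._members[target] |= self._members.pop(root)

    def nodes(self, cid):
        return self._members[self._find(cid)]

# Conv A and depthwise conv B start in separate clusters and get merged
# (step 4 above), so they end up pruned with a single shared mask
clusters = TinyClusters()
clusters.add_cluster(0, {'conv_a'})
clusters.add_cluster(1, {'dw_conv_b'})
clusters.merge([0, 1])
assert clusters.nodes(0) == {'conv_a', 'dw_conv_b'}
assert clusters.nodes(1) == clusters.nodes(0)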
Example #6
    def _prune_weights(self, target_model: NNCFNetwork):
        target_model_graph = target_model.get_original_graph()
        groups_of_nodes_to_prune = self.pruning_node_selector.create_pruning_groups(
            target_model_graph)

        device = next(target_model.parameters()).device
        insertion_commands = []
        self.pruned_module_groups_info = Clusterization[PrunedModuleInfo](
            lambda x: x.node_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                node_name = node.node_name
                module = target_model.get_containing_module(node_name)
                module_scope = target_model_graph.get_scope_by_node_name(
                    node_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_module(module)

                nncf_logger.info(
                    "Adding Weight Pruner in scope: {}".format(node_name))
                pruning_block = self.create_weight_pruning_operation(
                    module, node_name)
                # Hook for weights and bias
                hook = UpdateWeightAndBias(pruning_block).to(device)
                insertion_commands.append(
                    PTInsertionCommand(
                        PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                      target_node_name=node_name), hook,
                        TransformationPriority.PRUNING_PRIORITY))
                group_minfos.append(
                    PrunedModuleInfo(
                        node_name=node_name,
                        module_scope=module_scope,
                        module=module,
                        operand=pruning_block,
                        node_id=node.node_id,
                        is_depthwise=is_prunable_depthwise_conv(node)))

            cluster = Cluster[PrunedModuleInfo](
                i, group_minfos, [n.node_id for n in group.elements])
            self.pruned_module_groups_info.add_cluster(cluster)

        # Propagate masks to find norm layers to prune
        init_output_masks_in_graph(
            target_model_graph, self.pruned_module_groups_info.get_all_nodes())
        MaskPropagationAlgorithm(
            target_model_graph, PT_PRUNING_OPERATOR_METATYPES,
            PTNNCFPruningTensorProcessor).mask_propagation()

        # Also add binary masks for batch/group norms so the propagated masks can be applied to them
        types_to_apply_mask = ['group_norm']
        if self.prune_batch_norms:
            types_to_apply_mask.append('batch_norm')

        all_norm_layers = target_model_graph.get_nodes_by_types(
            types_to_apply_mask)
        for node in all_norm_layers:
            if node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue

            node_name = node.node_name
            module = target_model.get_containing_module(node_name)

            pruning_block = self.create_weight_pruning_operation(
                module, node_name)
            # Hook for weights and bias
            hook = UpdateWeightAndBias(pruning_block).to(device)
            insertion_commands.append(
                PTInsertionCommand(
                    PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                  target_node_name=node_name), hook,
                    TransformationPriority.PRUNING_PRIORITY))
            self._pruned_norms_operators.append((node, pruning_block, module))
        return insertion_commands
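
The UpdateWeightAndBias hook registered above multiplies the module's weight and bias by a binary mask before the layer executes. A simplified torch analogue using a plain forward pre-hook (NNCF's actual operator is more involved; the in-place masking here is a deliberate simplification):

import torch

conv = torch.nn.Conv2d(4, 4, kernel_size=3)
mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # per-output-channel binary mask

def apply_mask(module, inputs):
    # Zero out pruned output channels in weight and bias before the forward pass
    with torch.no_grad():
        module.weight.mul_(mask.view(-1, 1, 1, 1))
        if module.bias is not None:
            module.bias.mul_(mask)

conv.register_forward_pre_hook(apply_mask)
out = conv(torch.randn(1, 4, 8, 8))
assert torch.allclose(out[:, 1], torch.zeros_like(out[:, 1]))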
Example #7
    @classmethod
    def accept_pruned_input(cls, node: NNCFNode) -> bool:
        if is_grouped_conv(node) and not is_prunable_depthwise_conv(node):
            return False
        return True
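
The grouped-but-not-depthwise case rejected here is easy to see on plain torch modules: a depthwise convolution has groups == in_channels == out_channels, while other grouped convolutions split channels into blocks. A standalone sketch of a check of this kind (is_prunable_depthwise_conv encapsulates similar logic on NNCFNode attributes):

import torch

def looks_depthwise(conv: torch.nn.Conv2d) -> bool:
    # Illustrative check: every group handles exactly one input channel
    return conv.groups == conv.in_channels == conv.out_channels

depthwise = torch.nn.Conv2d(8, 8, kernel_size=3, groups=8)
grouped = torch.nn.Conv2d(8, 8, kernel_size=3, groups=2)

assert looks_depthwise(depthwise)
assert not looks_depthwise(grouped)  # grouped but not depthwise -> not prunable here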