Example #1
0
    def _set_binary_masks_for_pruned_layers_globally(self,
                                                     pruning_level: float):
        """
        Sets the binary mask values for layer groups according to the global pruning level.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_level proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.

        :param pruning_level: Proportion of all filters that positions the threshold inside
            the sorted global list of filter importances.
        """
        nncf_logger.debug(
            'Setting new binary masks for all pruned modules together.')
        filter_importances = {}
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        # a. Calculate importances for all groups of filters
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances[group.id] = cumulative_filters_importance

        # b. Calculate one threshold for all weights
        importances = tf.concat(list(filter_importances.values()), 0)
        filters_num = importances.shape[0]
        # Clamp the index to the last element: with pruning_level == 1.0 the raw
        # index equals the number of filters and would point past the end of the list.
        threshold_index = min(int(pruning_level * filters_num),
                              filters_num - 1)
        threshold = sorted(importances)[threshold_index]

        # c. Initialize masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            filter_mask = calculate_binary_mask(filter_importances[group.id],
                                                threshold)
            # Every node of the group receives the mask computed for the whole group.
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            # The first graph node (in topological order) with the layer's name is used.
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops with new masks
        self._update_benchmark_statistics()
Example #2
0
    def _set_binary_masks_for_pruned_layers_groupwise(self,
                                                      pruning_level: float):
        """
        Sets the binary mask values group by group: an individual importance threshold is
        picked for every group of pruned layers, the resulting masks are written to the
        graph nodes, propagated across the graph and applied to the wrapped layers.

        :param pruning_level: Pruning level forwarded to `get_rounded_pruned_element_number`
            for every group.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Drop the masks left on the graph nodes
        for graph_node in self._original_graph.topological_sort():
            graph_node.data.pop('output_mask', None)

        # 1. Compute a mask per group
        for group in self._pruned_layer_groups_info.get_all_clusters():
            # a. Cumulative importance of every filter in the group
            group_importance = self._calculate_filters_importance_in_group(
                group)
            filters_num = len(group_importance)

            # b. Threshold: the index never goes beyond the last filter of the group
            pruned_elems_num = get_rounded_pruned_element_number(
                group_importance.shape[0], pruning_level)
            threshold_index = min(pruned_elems_num, filters_num - 1)
            threshold = sorted(group_importance)[threshold_index]

            # c. Store the group mask on each node of the group
            group_mask = calculate_binary_mask(group_importance, threshold)
            for layer_info in group.elements:
                target_node = self._original_graph.get_node_by_id(
                    layer_info.nncf_node_id)
                target_node.data['output_mask'] = TFNNCFTensor(group_mask)

        # 2. Propagate the masks across the graph
        MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor).mask_propagation()

        # 3. Apply the masks to the model
        sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            # First node in topological order carrying the layer name
            same_name_nodes = [
                candidate for candidate in sorted_nodes
                if layer.name == candidate.layer_name
            ]
            layer_node = same_name_nodes[0]
            if layer_node.data['output_mask'] is None:
                continue
            self._set_operation_masks([layer],
                                      layer_node.data['output_mask'].tensor)

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()
def test_identity_mask_propogation_prune_ops(dummy_op_class):
    """
    Checks that every alias of an identity-mask-forward operation passes the
    producer's output mask through unchanged, both when the mask is None and
    when it is a tensor of ones.
    """
    assert dummy_op_class.accept_pruned_input(None)
    graph = NNCFGraph()
    producer = graph.add_nncf_node('conv_op', 'conv',
                                   dummy_types.DummyConvMetatype)
    producer_id = producer.node_id

    # One identity node per alias, each fed by the convolution
    consumer_ids = []
    for op_alias in dummy_op_class.get_all_op_aliases():
        consumer = graph.add_nncf_node(
            'identity', op_alias, dummy_types.DummyIdentityMaskForwardMetatype)
        graph.add_edge_between_nncf_nodes(
            from_node_id=producer_id,
            to_node_id=consumer.node_id,
            tensor_shape=[10] * 4,
            input_port_id=0,
            output_port_id=0,
            dtype=Dtype.FLOAT)
        consumer_ids.append(consumer.node_id)

    # Check with and without masks
    candidate_masks = [None, NPNNCFTensor(np.ones((10, )))]
    for output_mask in candidate_masks:
        graph.get_node_by_id(producer_id).data['output_mask'] = output_mask
        propagator = MaskPropagationAlgorithm(
            graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
            NPNNCFTensorProcessor)
        propagator.mask_propagation()
        for consumer_id in consumer_ids:
            propagated = graph.get_node_by_id(consumer_id).data['output_mask']
            assert np.all(propagated == output_mask)
def test_concat_output_tensor_device():
    """
    Checks that the mask produced for a concat node reports the device of the
    only input mask that was set (the mask of the last producer).
    """
    graph = NNCFGraph()
    producers = []
    for i in range(3):
        producers.append(
            graph.add_nncf_node(f'dummy_op_{i}',
                                DummyMaskProducerMetatype.name,
                                DummyMaskProducerMetatype))
    concat_node = graph.add_nncf_node(
        'concat_node',
        'concat',
        dummy_types.DummyConcatMetatype,
        layer_attributes=MultipleInputLayerAttributes(2))
    for producer in producers:
        graph.add_edge_between_nncf_nodes(
            from_node_id=producer.node_id,
            to_node_id=concat_node.node_id,
            tensor_shape=[10] * 4,
            input_port_id=0,
            output_port_id=0,
            dtype=Dtype.FLOAT)

    # Only the last producer gets a mask; the others are explicitly unmasked
    ref_device = 'some_test_device'
    *unmasked_producers, masked_producer = producers
    for producer in unmasked_producers:
        graph.get_node_by_id(producer.node_id).data['output_mask'] = None

    masked_node = graph.get_node_by_id(masked_producer.node_id)
    masked_node.data['output_mask'] = NPNNCFTensor(np.ones(10),
                                                   dummy_device=ref_device)

    # Propagate masks
    propagator = MaskPropagationAlgorithm(
        graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
        NPNNCFTensorProcessor)
    propagator.mask_propagation()

    # The concat mask must live on the reference device
    concat_mask = graph.get_node_by_id(concat_node.node_id).data['output_mask']
    assert concat_mask.device == ref_device
def test_group_norm_pruning_ops(num_channels, num_groups,
                                accept_pruned_input_ref):
    """
    Checks the group norm pruning operation: its `accept_pruned_input` decision
    must match the reference, and the mask is expected to be forwarded only
    when pruned input is accepted (otherwise the output mask must be None).
    """
    graph = NNCFGraph()
    conv_node = graph.add_nncf_node('conv_op', 'conv',
                                    dummy_types.DummyConvMetatype)
    norm_attributes = GroupNormLayerAttributes(True,
                                               num_channels=num_channels,
                                               num_groups=num_groups)
    norm_node = graph.add_nncf_node(
        'identity',
        dummy_types.DummyGroupNormMetatype.name,
        dummy_types.DummyGroupNormMetatype,
        layer_attributes=norm_attributes)
    accepts_pruned_input = dummy_types.DummyGroupNormPruningOp.accept_pruned_input(
        norm_node)
    assert accepts_pruned_input == accept_pruned_input_ref

    graph.add_edge_between_nncf_nodes(from_node_id=conv_node.node_id,
                                      to_node_id=norm_node.node_id,
                                      tensor_shape=[10] * 4,
                                      input_port_id=0,
                                      output_port_id=0,
                                      dtype=Dtype.FLOAT)

    # Check with and without masks
    for conv_mask in [None, NPNNCFTensor(np.ones((10, )))]:
        graph.get_node_by_id(conv_node.node_id).data['output_mask'] = conv_mask
        propagator = MaskPropagationAlgorithm(
            graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
            NPNNCFTensorProcessor)
        propagator.mask_propagation()

        # A group norm that rejects pruned input must not forward any mask
        expected_mask = conv_mask if accept_pruned_input_ref else None
        norm_mask = graph.get_node_by_id(norm_node.node_id).data['output_mask']
        assert np.all(norm_mask == expected_mask)
def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input,
                          conv_type):
    """
    Checks the (transpose) convolution pruning operation: the
    `accept_pruned_input` decision against the reference and the resulting
    output mask for every combination of present/absent input and output masks.
    """
    fixed_conv_kwargs = dict(weight_requires_grad=True,
                             kernel_size=(2, 2),
                             stride=(1, 1),
                             padding_values=[0, 0])
    graph = NNCFGraph()
    producer = graph.add_nncf_node('dummy_op_before',
                                   DummyMaskProducerMetatype.name,
                                   DummyMaskProducerMetatype)
    conv_attributes = ConvolutionLayerAttributes(transpose=transpose,
                                                 **layer_attributes,
                                                 **fixed_conv_kwargs)
    conv_node = graph.add_nncf_node('conv_op_target',
                                    dummy_types.DummyConvMetatype.name,
                                    dummy_types.DummyConvMetatype,
                                    layer_attributes=conv_attributes)
    producer_id = producer.node_id
    conv_id = conv_node.node_id
    graph.add_edge_between_nncf_nodes(
        from_node_id=producer_id,
        to_node_id=conv_id,
        tensor_shape=[layer_attributes['in_channels']] * 4,
        input_port_id=0,
        output_port_id=0,
        dtype=Dtype.FLOAT)

    if transpose:
        pruning_op_class = dummy_types.DummyTransposeConvPruningOp
    else:
        pruning_op_class = dummy_types.DummyConvPruningOp
    assert pruning_op_class.accept_pruned_input(
        conv_node) == ref_accept_pruned_input

    ones_input_mask = NPNNCFTensor(np.ones(
        (layer_attributes['in_channels'], )))
    ones_output_mask = NPNNCFTensor(
        np.ones((layer_attributes['out_channels'], )))

    grouped_conv_types = ('grouped_conv_no_depthwise', 'multiply_grouped_conv')
    # Check all combinations of masks
    for input_mask in (None, ones_input_mask):
        for output_mask in (None, ones_output_mask):
            graph.get_node_by_id(producer_id).data['output_mask'] = input_mask
            graph.get_node_by_id(conv_id).data['output_mask'] = output_mask
            propagator = MaskPropagationAlgorithm(
                graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
                NPNNCFTensorProcessor)
            propagator.mask_propagation()

            result_mask = graph.get_node_by_id(conv_id).data['output_mask']
            if conv_type in grouped_conv_types:
                # These grouped convolutions are expected to end up without a mask
                assert result_mask is None
                continue
            # A usual convolution keeps its own mask, any other type takes the input one
            expected_mask = output_mask if conv_type == 'usual_conv' else input_mask
            assert np.all(result_mask == expected_mask)
Example #7
0
def test_symbolic_mask_propagation(test_input_info_struct_):
    """
    Runs symbolic mask propagation on the model described by the test struct and
    checks that no output masks are left in the graph and that the resulting
    pruning decisions equal the reference ones.
    """
    model = test_input_info_struct_.model()
    prune_first, *_ = test_input_info_struct_.prune_params
    nncf_model, _ = create_nncf_model_and_pruning_builder(
        model, {'prune_first_conv': prune_first})

    pruning_types = []
    for pruning_module in NNCF_PRUNING_MODULES_DICT:
        pruning_types.append(pruning_module.op_func_name)

    graph = nncf_model.get_graph()
    propagation_algo = MaskPropagationAlgorithm(
        graph, PT_PRUNING_OPERATOR_METATYPES, SymbolicMaskProcessor)
    decisions = propagation_algo.symbolic_mask_propagation(
        pruning_types, test_input_info_struct_.can_prune_after_analysis)

    # Check all output masks are deleted
    for graph_node in graph.get_all_nodes():
        assert graph_node.data['output_mask'] is None

    # Check ref decisions
    ref_decisions = test_input_info_struct_.final_can_prune
    assert len(decisions) == len(ref_decisions)
    for node_idx in decisions:
        assert decisions[node_idx] == ref_decisions[node_idx]
def test_elementwise_prune_ops(valid_masks):
    """
    Checks mask propagation through an elementwise operation fed by two
    convolutions.

    `valid_masks` is None (no masks are set at all), truthy (equal masks of
    ones on both inputs) or falsy (mismatching masks, propagation is expected
    to fail with an AssertionError).
    """
    graph = NNCFGraph()
    conv_op_0 = graph.add_nncf_node('conv_op_0',
                                    dummy_types.DummyConvMetatype.name,
                                    dummy_types.DummyConvMetatype)
    conv_op_1 = graph.add_nncf_node('conv_op_1',
                                    dummy_types.DummyConvMetatype.name,
                                    dummy_types.DummyConvMetatype)
    elementwise_op = graph.add_nncf_node(
        'elementwise', dummy_types.DummyElementwiseMetatype.name,
        dummy_types.DummyElementwiseMetatype)
    connect = partial(graph.add_edge_between_nncf_nodes,
                      tensor_shape=[10] * 4,
                      input_port_id=0,
                      output_port_id=0,
                      dtype=Dtype.FLOAT)
    # conv_op_0 -> elementwise
    connect(from_node_id=conv_op_0.node_id, to_node_id=elementwise_op.node_id)

    # conv_op_1 -> elementwise
    connect(from_node_id=conv_op_1.node_id, to_node_id=elementwise_op.node_id)

    if valid_masks is not None:
        masks = [NPNNCFTensor(np.ones((10, ))), NPNNCFTensor(np.ones((10, )))]
    else:
        masks = [None, None]

    def set_masks(masks, ops):
        # Writes masks[i] as the output mask of ops[i]
        for op, op_mask in zip(ops, masks):
            graph.get_node_by_id(op.node_id).data['output_mask'] = op_mask

    def propagate():
        MaskPropagationAlgorithm(graph,
                                 dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
                                 NPNNCFTensorProcessor).mask_propagation()

    def check_wrong_masks(masks):
        # Setting and propagating inconsistent masks must trip an assertion
        with pytest.raises(AssertionError):
            set_masks(masks, [conv_op_0, conv_op_1])
            propagate()

    if valid_masks is not None and not valid_masks:
        # Masks of the same length but with different values
        masks[0].tensor[0] = 0
        check_wrong_masks(masks)
        # Masks of different lengths
        masks[0] = NPNNCFTensorProcessor.concatenate(
            [masks[1], NPNNCFTensor(np.array([1]))], axis=0)
        check_wrong_masks(masks)
        return

    if valid_masks:
        set_masks(masks, [conv_op_0, conv_op_1])
    propagate()
    result_mask = graph.get_node_by_id(
        elementwise_op.node_id).data['output_mask']
    assert np.all(result_mask == masks[0])
Example #9
0
    def _pruning_dimensions_analysis(
            self, graph,
            can_prune_after_check) -> Dict[int, PruningAnalysisDecision]:
        """
        Checks that all nodes marked as prunable after the model analysis and the compatibility
        check vs. pruning algo have a correct correspondent closing node on each path from
        the node itself to the outputs.

        :param graph: Graph to work with.
        :param can_prune_after_check: Dict of node indices vs the decision made by previous steps;
            the decision is true only for the nodes that do not conflict with mask propagation and
            are supported by the NNCF pruning algorithm
        :return: Pruning node analysis after model analyzer, pruning algo compatibility and pruning dimensions checks.
        """
        propagation_algo = MaskPropagationAlgorithm(
            graph, self._pruning_operator_metatypes)
        decisions_by_dim = propagation_algo.symbolic_mask_propagation(
            self._prune_operations_types, can_prune_after_check)

        # Join the previous decision with the dimension-based one for every analysed node
        joined_decisions = {}
        for node_id in decisions_by_dim:
            previous_decision = can_prune_after_check[node_id]
            joined_decisions[node_id] = previous_decision.join(
                decisions_by_dim[node_id])

        # Nodes without a dimension-based decision keep their previous decision
        result = can_prune_after_check.copy()
        result.update(joined_decisions)
        return result
Example #10
0
    def _propagate_masks(self):
        """
        Propagates the pruning masks across the original graph of the model and
        assigns the propagated masks to the pruning blocks of the norm operators.
        """
        nncf_logger.debug("Propagating pruning masks")
        # 1. Propagate masks for all modules
        graph = self.model.get_original_graph()
        pruned_nodes = self.pruned_module_groups_info.get_all_nodes()
        init_output_masks_in_graph(graph, pruned_nodes)

        propagator = MaskPropagationAlgorithm(graph,
                                              PT_PRUNING_OPERATOR_METATYPES,
                                              PTNNCFPruningTensorProcessor)
        propagator.mask_propagation()

        # 2. Set the masks for Batch/Group Norms
        handled_modules = []
        for norm_node, pruning_block, norm_module in self._pruned_norms_operators:
            if norm_module in handled_modules:
                # The mask of this module was already set by an earlier entry
                continue
            # Setting masks for BN nodes
            propagated_mask = norm_node.data['output_mask']
            pruning_block.binary_filter_pruning_mask = propagated_mask.tensor
            handled_modules.append(norm_module)
def test_convs_elementwise_source_before_concat(empty_mask_right_branch,
                                                empty_mask_left_branch,
                                                right_branch_output_channels):
    """
    Checks the concat output mask for the graph

        conv_op_0 -+
                   +-> elementwise_node -+
        conv_op_1 -+                     +-> concat_node
                               conv_op_2 -+

    where masks of ones are optionally set on the left (conv_op_0, conv_op_1)
    and on the right (conv_op_2) branch.
    """
    graph = NNCFGraph()
    conv_op_0, conv_op_1, conv_op_2 = [
        graph.add_nncf_node(f'conv_op_{i}', 'conv',
                            dummy_types.DummyConvMetatype) for i in range(3)
    ]
    elementwise_node = graph.add_nncf_node(
        'elementwise_node', 'elementwise',
        dummy_types.DummyElementwiseMetatype)
    concat_node = graph.add_nncf_node(
        'concat_node',
        'concat',
        dummy_types.DummyConcatMetatype,
        layer_attributes=MultipleInputLayerAttributes(2))

    # (source, target, tensor shape) of every edge, in creation order
    edges = [
        (conv_op_0, elementwise_node, [10] * 4),
        (conv_op_1, elementwise_node, [10] * 4),
        (elementwise_node, concat_node, [10] * 4),
        (conv_op_2, concat_node, [10, 10, right_branch_output_channels, 10]),
    ]
    for source, target, shape in edges:
        graph.add_edge_between_nncf_nodes(from_node_id=source.node_id,
                                          to_node_id=target.node_id,
                                          tensor_shape=shape,
                                          input_port_id=0,
                                          output_port_id=0,
                                          dtype=Dtype.FLOAT)

    # Set masks
    masked_convs = []
    if not empty_mask_left_branch:
        masked_convs += [(conv_op_0, 10), (conv_op_1, 10)]
    if not empty_mask_right_branch:
        masked_convs.append((conv_op_2, right_branch_output_channels))
    for conv_op, channels in masked_convs:
        conv_node = graph.get_node_by_id(conv_op.node_id)
        conv_node.data['output_mask'] = NPNNCFTensor(np.ones(channels))

    # Propagate masks
    propagator = MaskPropagationAlgorithm(
        graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
        NPNNCFTensorProcessor)
    propagator.mask_propagation()

    # Check with masks
    concat_mask = graph.get_node_by_id(concat_node.node_id).data['output_mask']
    if empty_mask_left_branch and empty_mask_right_branch:
        assert concat_mask is None
        return
    reference_mask = np.ones((10 + right_branch_output_channels, ))
    np.testing.assert_equal(concat_mask.tensor, reference_mask)
 def check_wrong_masks(masks):
     """
     Expects an AssertionError while the given masks are set on conv_op_0 and
     conv_op_1 and propagated across the graph.
     """
     with pytest.raises(AssertionError):
         set_masks(masks, [conv_op_0, conv_op_1])
         propagator = MaskPropagationAlgorithm(
             graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES,
             NPNNCFTensorProcessor)
         propagator.mask_propagation()
Example #13
0
    def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Computes necessary model transformations (pruning mask insertions) to enable pruning.

        Besides building the layout, the method stores the converted graph in `self._graph`
        and the clusters of pruned layers in `self._pruned_layer_groups_info`.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        converter = TFModelConverterFactory.create(model)
        self._graph = converter.convert()
        groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

        transformations = TFTransformationLayout()
        # Names of shared layers that already received their mask insertions
        shared_layers = set()

        self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)
                group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                    is_prunable_depthwise_conv(node)))

                # Add output_mask to elements to run mask_propagation
                # and detect spec_nodes that will be pruned.
                # It should be done for all elements of shared layer.
                node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
                if layer_name in shared_layers:
                    continue
                if node.is_shared():
                    shared_layers.add(layer_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

                _, layer_info = converter.get_layer_info_for_node(node.node_name)
                for weight_def in node.metatype.weight_definitions:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, weight_def.weight_attr_name)
                    )
                # The bias is masked only when the layer actually has one
                if node.metatype.bias_attr_name is not None and \
                        getattr(layer, node.metatype.bias_attr_name) is not None:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, node.metatype.bias_attr_name)
                    )

            cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
            self._pruned_layer_groups_info.add_cluster(cluster)

        # Propagating masks across the graph to detect spec_nodes that will be pruned
        mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                                   TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # Add masks for all spec modules, because prunable batchnorm layers can be determined
        # at the moment of mask propagation
        types_spec_layers = [TFBatchNormalizationLayerMetatype] \
            if self._prune_batch_norms else []

        spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
        for spec_node in spec_nodes:
            layer_name = get_layer_identifier(spec_node)
            layer = model.get_layer(layer_name)
            if spec_node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue
            if layer_name in shared_layers:
                continue
            if spec_node.is_shared():
                shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
            for weight_def in spec_node.metatype.weight_definitions:
                if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                        and not layer.scale and weight_def.weight_attr_name == 'gamma':
                    nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                      'Do not add mask to it.')
                    continue

                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            # Same guard as for the prunable layers above: do not register a mask
            # for a bias attribute that the metatype does not define or that the
            # layer does not have (the attribute is None).
            bias_attr_name = spec_node.metatype.bias_attr_name
            if bias_attr_name is not None and getattr(layer, bias_attr_name) is not None:
                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, bias_attr_name)
                )
        return transformations
Example #14
0
    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_rate: float):
        """
        Prunes least important filters one-by-one until target FLOPs pruning rate is achieved.
        Filters are sorted by filter importance score.

        :param target_flops_pruning_rate: Proportion of `self.full_flops` that has to be
            pruned away; the target is `self.full_flops * (1 - target_flops_pruning_rate)`.
        :raises RuntimeError: If all filters were visited and the FLOPs count
            is still above the target.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        target_flops = self.full_flops * (1 - target_flops_pruning_rate)
        wrapped_layers = collect_wrapped_layers(self._model)
        # Masks are stored per group id, so a mapping is required here:
        # assigning `masks[group.id]` on an empty list raises IndexError.
        masks = {}

        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes
                if layer.layer.name == get_original_name(n.node_name)
            ][0]
            nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer))

        # 1. Calculate importances for all groups of filters. Initialize masks.
        # The three lists below are parallel: entry i describes one filter by its
        # importance, the id of its group and its index inside the group.
        filter_importances = []
        group_indexes = []
        filter_indexes = []
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = \
                self._calculate_filters_importance_in_group(group, wrapped_layers)

            filter_importances.extend(cumulative_filters_importance)
            filters_num = len(cumulative_filters_importance)
            group_indexes.extend([group.id] * filters_num)
            filter_indexes.extend(range(filters_num))
            masks[group.id] = tf.ones(filters_num)

        # 2. Prune filters one-by-one, starting from the least important one,
        # until the FLOPs target is reached.
        tmp_in_channels = self._layers_in_channels.copy()
        tmp_out_channels = self._layers_out_channels.copy()
        sorted_importances = sorted(zip(filter_importances, group_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        for _, group_id, filter_index in sorted_importances:
            if self._pruning_quotas[group_id] == 0:
                # No more filters may be pruned in this group
                continue
            # TensorFlow tensors do not support item assignment, so a new tensor
            # with the selected filter zeroed out replaces the group mask.
            masks[group_id] = tf.tensor_scatter_nd_update(
                masks[group_id], [[filter_index]], [0])
            self._pruning_quotas[group_id] -= 1

            # Update input/output shapes of pruned nodes
            group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
            for node in group.nodes:
                tmp_out_channels[node.layer_name] -= 1
            for node_name in self._next_nodes[group_id]:
                tmp_in_channels[node_name] -= 1

            flops = sum(
                count_flops_for_nodes(self._original_graph,
                                      self._layers_in_shapes,
                                      self._layers_out_shapes,
                                      input_channels=tmp_in_channels,
                                      output_channels=tmp_out_channels,
                                      conv_op_types=GENERAL_CONV_LAYERS,
                                      linear_op_types=LINEAR_LAYERS).values())
            if flops <= target_flops:
                # 3. Add masks to the graph and propagate them
                for group in self._pruned_layer_groups_info.get_all_clusters():
                    for node in group.nodes:
                        nncf_node = self._original_graph.get_node_by_id(
                            node.nncf_node_id)
                        nncf_node.data['output_mask'] = masks[group.id]

                mask_propagator = MaskPropagationAlgorithm(
                    self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
                mask_propagator.mask_propagation()

                # 4. Set binary masks to the model
                self.current_flops = flops
                nncf_sorted_nodes = self._original_graph.topological_sort()
                for layer in wrapped_layers:
                    nncf_node = [
                        n for n in nncf_sorted_nodes
                        if layer.layer.name == get_original_name(n.node_name)
                    ][0]
                    if nncf_node.data['output_mask'] is not None:
                        self._set_operation_masks(
                            [layer], nncf_node.data['output_mask'])
                return
        raise RuntimeError(
            f'Unable to prune model to required flops pruning rate:'
            f' {target_flops_pruning_rate}')
Example #15
0
    def get_transformation_layout(
            self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Builds the layout of binary pruning mask insertions for the given model.

        The model is converted to an NNCF graph (stored in `self._graph`), pruning
        groups are selected on that graph, and a mask insertion is registered for
        the weight/bias attributes of each layer in those groups. The groups are
        recorded in `self._pruned_layer_groups_info`. Masks are then propagated
        over the graph and mask insertions are additionally registered for the
        special layers whose `output_mask` is not None after propagation.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        self._graph = convert_keras_model_to_nncf_graph(model)
        pruning_groups = self._pruning_node_selector.create_pruning_groups(
            self._graph)

        layout = TFTransformationLayout()
        # Names of shared layers that have already been handled once.
        handled_shared_layers = set()
        mask_attr_keys = (WEIGHT_ATTR_NAME, BIAS_ATTR_NAME)

        self._pruned_layer_groups_info = Clusterization('layer_name')

        for cluster_id, group in enumerate(pruning_groups.get_all_clusters()):
            pruned_layer_infos = []
            for node in group.nodes:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)

                # Every node gets an initial all-ones output mask (including each
                # node of a shared layer) so that mask propagation can run and
                # reveal which special nodes end up pruned.
                node.data['output_mask'] = tf.ones(
                    node.module_attributes.out_channels)

                # A shared layer is processed only for the first node seen.
                if layer_name in handled_shared_layers:
                    continue
                if is_shared(node):
                    handled_shared_layers.add(layer_name)

                # The selected layer must be one whose weights are pruned.
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

                attr_names_by_key = LAYERS_WITH_WEIGHTS[node.node_type]
                for key in mask_attr_keys:
                    attr_name = attr_names_by_key[key]
                    if getattr(layer, attr_name) is None:
                        # Attribute is unset on this layer: nothing to mask.
                        continue
                    command = self._get_insertion_command_binary_mask(
                        layer_name, attr_name)
                    layout.register(command)
                pruned_layer_infos.append(
                    PrunedLayerInfo(layer_name, node.node_id))

            group_node_ids = [n.node_id for n in group.nodes]
            self._pruned_layer_groups_info.add_cluster(
                NodesCluster(cluster_id, pruned_layer_infos, group_node_ids))

        # Run mask propagation so that the special nodes affected by pruning
        # end up with a non-None output mask.
        propagation = MaskPropagationAlgorithm(
            self._graph, TF_PRUNING_OPERATOR_METATYPES)
        propagation.mask_propagation()

        # Special layers (e.g. batch norms) can only be identified as prunable
        # once the masks have been propagated, so they are handled here.
        spec_layer_types = list(SPECIAL_LAYERS_WITH_WEIGHTS)
        if not self._prune_batch_norms:
            spec_layer_types.remove('BatchNormalization')

        for spec_node in self._graph.get_nodes_by_types(spec_layer_types):
            layer_name = get_layer_identifier(spec_node)
            # Skip nodes that will not be pruned and shared layers already handled.
            if (spec_node.data['output_mask'] is None
                    or layer_name in handled_shared_layers):
                continue
            if is_shared(spec_node):
                handled_shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)
            for key in mask_attr_keys:
                attr_name = SPECIAL_LAYERS_WITH_WEIGHTS[
                    spec_node.node_type][key]
                command = self._get_insertion_command_binary_mask(
                    layer_name, attr_name)
                layout.register(command)
        return layout
Example #16
0
    def _prune_weights(self, target_model: NNCFNetwork):
        """
        Creates the weight pruning operations and their insertion commands.

        Pruning groups are selected on the original graph of `target_model`; for
        every node of every group a pruning operation is created, wrapped into an
        `UpdateWeightAndBias` hook and scheduled as a pre-layer insertion. The
        groups are stored in `self.pruned_module_groups_info`. Masks are then
        propagated over the graph, and the same kind of hook is scheduled for the
        norm layers whose `output_mask` is not None; these are also appended to
        `self._pruned_norms_operators`.

        :param target_model: The model to be pruned.
        :return: The list of insertion commands for the created pruning hooks.
        """
        graph = target_model.get_original_graph()
        pruning_groups = self.pruning_node_selector.create_pruning_groups(
            graph)

        # Hooks are moved to the device of the model's first parameter.
        device = next(target_model.parameters()).device
        commands = []
        self.pruned_module_groups_info = Clusterization[PrunedModuleInfo](
            lambda info: info.node_name)

        def _build_pruning_command(module, node_name):
            # Creates the pruning operation for `module` and the command that
            # inserts its weight-and-bias hook before the layer operation.
            operation = self.create_weight_pruning_operation(
                module, node_name)
            hook = UpdateWeightAndBias(operation).to(device)
            target_point = PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                         target_node_name=node_name)
            command = PTInsertionCommand(
                target_point, hook, TransformationPriority.PRUNING_PRIORITY)
            return operation, command

        for cluster_id, group in enumerate(pruning_groups.get_all_clusters()):
            module_infos = []
            for node in group.elements:
                node_name = node.node_name
                module = target_model.get_containing_module(node_name)
                module_scope = graph.get_scope_by_node_name(node_name)
                # The selected module must be one whose weights are pruned.
                assert self._is_pruned_module(module)

                nncf_logger.info(
                    "Adding Weight Pruner in scope: {}".format(node_name))
                pruning_block, command = _build_pruning_command(
                    module, node_name)
                commands.append(command)
                module_infos.append(
                    PrunedModuleInfo(
                        node_name=node_name,
                        module_scope=module_scope,
                        module=module,
                        operand=pruning_block,
                        node_id=node.node_id,
                        is_depthwise=is_prunable_depthwise_conv(node)))

            element_ids = [n.node_id for n in group.elements]
            self.pruned_module_groups_info.add_cluster(
                Cluster[PrunedModuleInfo](cluster_id, module_infos,
                                          element_ids))

        # Initialise and propagate masks to find out which norm layers are pruned.
        init_output_masks_in_graph(
            graph, self.pruned_module_groups_info.get_all_nodes())
        propagation = MaskPropagationAlgorithm(
            graph, PT_PRUNING_OPERATOR_METATYPES,
            PTNNCFPruningTensorProcessor)
        propagation.mask_propagation()

        # Group norms always get a mask operation; batch norms only on request.
        if self.prune_batch_norms:
            norm_types = ['group_norm', 'batch_norm']
        else:
            norm_types = ['group_norm']

        for node in graph.get_nodes_by_types(norm_types):
            if node.data['output_mask'] is None:
                # No mask reached this node, so it will not be pruned.
                continue

            node_name = node.node_name
            module = target_model.get_containing_module(node_name)
            pruning_block, command = _build_pruning_command(module, node_name)
            commands.append(command)
            self._pruned_norms_operators.append((node, pruning_block, module))
        return commands