Exemplo n.º 1
0
    def _sparsify_weights(
            self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
        """
        Build the insertion commands that attach a weight sparsifier to the
        weighted nodes of the model.

        Nodes rejected by `self._should_consider_scope` are logged and skipped.
        For every other node a sparsifying operation is created, moved to the
        device of the model's first parameter, wrapped into an insertion
        command and recorded in `self._sparsified_module_info`.

        :param target_model: network whose weighted nodes are processed.
        :return: the insertion commands, one per node that was not skipped.
        """
        model_device = next(target_model.parameters()).device
        candidate_nodes = target_model.get_weighted_original_graph_nodes(
            nncf_module_names=self.compressed_nncf_module_names)
        commands = []
        for candidate in candidate_nodes:
            scope_name = candidate.node_name

            if self._should_consider_scope(scope_name):
                nncf_logger.info(
                    "Adding Weight Sparsifier in scope: {}".format(scope_name))
                lr_multiplier = self.config.get_redefinable_global_param_value_for_algo(
                    'compression_lr_multiplier', self.name)
                sparsifier = self.create_weight_sparsifying_operation(
                    candidate, lr_multiplier)
                sparsifier_hook = sparsifier.to(model_device)

                target_point = PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
                                             target_node_name=scope_name)
                commands.append(
                    PTInsertionCommand(
                        target_point, sparsifier_hook,
                        TransformationPriority.SPARSIFICATION_PRIORITY))

                # Remember which module got which hook for later bookkeeping.
                owning_module = target_model.get_containing_module(scope_name)
                self._sparsified_module_info.append(
                    SparseModuleInfo(scope_name, owning_module,
                                     sparsifier_hook))
            else:
                nncf_logger.info(
                    "Ignored adding Weight Sparsifier in scope: {}".format(
                        scope_name))

        return commands
Exemplo n.º 2
0
def add_adjust_padding_nodes(bitwidth_graph: nx.DiGraph,
                             model: NNCFNetwork) -> nx.DiGraph:
    """
    Add an auxiliary 'adjust_padding_value' node in front of every graph node
    whose containing module is an NNCFConv2d with an UpdatePaddingValue
    pre-op.

    The new node is keyed '<node_key>_apad' and gets a single edge pointing
    to the original node. The graph is modified in place.

    :param bitwidth_graph: graph to extend; its node keys are looked up in
        the model's NNCF graph.
    :param model: network providing the NNCF graph and the modules.
    :return: the same `bitwidth_graph` object, after modification.
    """
    # pylint:disable=protected-access

    nncf_graph = model.get_graph()

    # Collect first and mutate afterwards: nodes must not be added to the
    # graph while its node view is being iterated.
    keys_with_adjust_padding = []
    for node_key in bitwidth_graph.nodes:
        node = nncf_graph.get_node_by_key(node_key)
        module = model.get_containing_module(node.node_name)
        if not isinstance(module, NNCFConv2d):
            continue
        # Several matching pre-ops would all map to the same '<key>_apad'
        # node, so a single entry per convolution is enough.
        if any(isinstance(pre_op, UpdatePaddingValue)
               for pre_op in module.pre_ops.values()):
            keys_with_adjust_padding.append(node_key)

    for parent_node_key in keys_with_adjust_padding:
        new_node_key = f'{parent_node_key}_apad'
        bitwidth_graph.add_node(new_node_key,
                                type='',
                                label='adjust_padding_value',
                                style='filled',
                                color='yellow')
        bitwidth_graph.add_edge(new_node_key, parent_node_key)
    return bitwidth_graph
Exemplo n.º 3
0
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        """
        Shrink the input channels of a convolution module according to the
        first input mask stored on the node.

        Does nothing when the mask is None. For a depthwise convolution only
        `groups` and `in_channels` are updated and the weight is left
        untouched here; otherwise dimension 1 of the weight is filtered by
        the mask and `in_channels` is updated.

        :param model: network that owns the module to prune.
        :param node: graph node; `node.data['input_masks'][0]` is the mask
            and `node.data['key']` is used in the log message.
        :param graph: not used by this method.
        """
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        # Count the kept channels on the boolean mask that is also used for
        # indexing the weight, so the recorded number of channels always
        # matches the number of selected slices.
        new_num_channels = int(torch.sum(bool_mask))

        is_depthwise = is_prunable_depthwise_conv(node)
        node_module = model.get_containing_module(node.node_name)

        if is_depthwise:
            # In depthwise case prune output channels by input mask, here only fix for new number of input channels
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
            old_num_channels = int(node_module.weight.size(0))
        else:
            old_num_channels = int(node_module.weight.size(1))
            out_channels = node_module.weight.size(0)
            # Repeat the per-input-channel mask for every output channel so
            # that it selects along the first two weight dimensions at once.
            broadcasted_mask = bool_mask.repeat(out_channels).view(
                out_channels, bool_mask.size(0))
            new_weight_shape = list(node_module.weight.shape)
            new_weight_shape[1] = new_num_channels

            node_module.in_channels = new_num_channels
            node_module.weight = torch.nn.Parameter(
                node_module.weight[broadcasted_mask].view(new_weight_shape))

        nncf_logger.info(
            'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          new_num_channels))
Exemplo n.º 4
0
    def output_prune(cls, model: NNCFNetwork, node: NNCFNode,
                     graph: NNCFGraph) -> None:
        """
        Shrink the output channels of a transposed convolution module
        according to the node's output mask.

        Does nothing when the mask is None. Otherwise dimension 1 of the
        weight is filtered by the mask, the bias (if present) is filtered the
        same way and `out_channels` is updated.

        :param model: network that owns the module to prune.
        :param node: graph node; `node.data['output_mask']` is the mask and
            `node.data['key']` is used in the log message.
        :param graph: not used by this method.
        """
        output_mask = node.data['output_mask']
        if output_mask is None:
            return

        bool_mask = torch.tensor(output_mask, dtype=torch.bool)
        kept_channels = int(torch.sum(bool_mask))

        node_module = model.get_containing_module(node.node_name)
        weight = node_module.weight
        old_num_channels = int(weight.size(1))

        # Repeat the mask once per slice of weight dimension 0 so that it
        # selects along the first two weight dimensions at once.
        dim0_size = weight.size(0)
        mask_per_dim0 = bool_mask.repeat(dim0_size).view(
            dim0_size, bool_mask.size(0))
        pruned_shape = list(weight.shape)
        pruned_shape[1] = kept_channels

        node_module.out_channels = kept_channels
        node_module.weight = torch.nn.Parameter(
            weight[mask_per_dim0].view(pruned_shape))

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          node_module.out_channels))
Exemplo n.º 5
0
def test_pruning_node_selector(
        test_input_info_struct_: GroupPruningModulesTestStruct):
    """
    Check that PruningNodeSelector groups the nodes of a test model as
    described by the test structure.

    Two things are verified: none of the modules listed as non-pruned belongs
    to a pruning group, and every expected group matches the cluster that
    contains its first node id.
    """
    model_factory = test_input_info_struct_.model
    expected_non_pruned = test_input_info_struct_.non_pruned_module_nodes
    expected_groups = test_input_info_struct_.pruned_groups_by_node_id
    prune_first, prune_downsample = test_input_info_struct_.prune_params

    pruning_operations = [v.op_func_name for v in NNCF_PRUNING_MODULES_DICT]
    grouping_operations = PTElementwisePruningOp.get_all_op_aliases()
    from nncf.common.pruning.node_selector import PruningNodeSelector
    selector = PruningNodeSelector(PT_PRUNING_OPERATOR_METATYPES,
                                   pruning_operations,
                                   grouping_operations, None,
                                   None, prune_first,
                                   prune_downsample)
    model = model_factory()
    model.eval()
    nncf_network = NNCFNetwork(model,
                               input_infos=[ModelInputInfo([1, 1, 8, 8])])
    pruning_groups = selector.create_pruning_groups(
        nncf_network.get_original_graph())

    # 1. Check all not pruned modules
    all_pruned_modules = []
    for pruned_node in pruning_groups.get_all_nodes():
        all_pruned_modules.append(
            nncf_network.get_containing_module(pruned_node.node_name))
    for node_name in expected_non_pruned:
        module = nncf_network.get_containing_module(node_name)
        assert module is not None
        assert module not in all_pruned_modules

    # 2. Check that all pruned groups are valid
    for expected_ids in expected_groups:
        cluster = pruning_groups.get_cluster_containing_element(
            expected_ids[0])
        # Counter comparison is order-independent, so no sorting is needed.
        actual_ids = Counter(n.node_id for n in cluster.elements)

        assert actual_ids == Counter(expected_ids)
Exemplo n.º 6
0
def get_bn_for_conv_node_by_name(
        target_model: NNCFNetwork,
        conv_node_name: NNCFNodeName) -> Optional[torch.nn.Module]:
    """
    Returns the module of the batch norm node that `get_bn_node_for_conv`
    finds for the given convolution node in the model's original graph.

    :param target_model: NNCFNetwork to work with
    :param conv_node_name: name of the convolution node in the original graph
    :return: the module containing the batch norm node, or None if no batch
        norm node was found for the convolution
    """
    original_graph = target_model.get_original_graph()
    bn_node = get_bn_node_for_conv(
        original_graph, original_graph.get_node_by_name(conv_node_name))
    return (None if bn_node is None
            else target_model.get_containing_module(bn_node.node_name))
Exemplo n.º 7
0
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        """
        Shrink the input channels of a transposed convolution module
        according to the first input mask stored on the node.

        Does nothing when the mask is None. Otherwise dimension 0 of the
        weight is filtered by the mask and `in_channels` is set to the number
        of kept entries.

        :param model: network that owns the module to prune.
        :param node: graph node; `node.data['input_masks'][0]` is the mask
            and `node.data['key']` is used in the log message.
        :param graph: not used by this method.
        """
        raw_mask = node.data['input_masks'][0]
        if raw_mask is None:
            return
        keep = torch.tensor(raw_mask, dtype=torch.bool)

        target_module = model.get_containing_module(node.node_name)
        channels_before = int(target_module.weight.size(0))

        target_module.in_channels = int(torch.sum(keep))
        pruned_weight = target_module.weight[keep]
        target_module.weight = torch.nn.Parameter(pruned_weight)

        nncf_logger.info(
            'Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(node.data['key'], channels_before,
                          target_module.in_channels))
Exemplo n.º 8
0
    def output_prune(cls, model: NNCFNetwork, node: NNCFNode,
                     graph: NNCFGraph) -> None:
        """
        Shrink the output channels of a convolution module according to the
        node's output mask.

        Does nothing when the mask is None. Otherwise dimension 0 of the
        weight is filtered by the mask, the bias (if present) is filtered the
        same way and `out_channels` is updated.

        :param model: network that owns the module to prune.
        :param node: graph node; `node.data['output_mask']` is the mask and
            `node.data['key']` is used in the log message.
        :param graph: not used by this method.
        """
        mask = node.data['output_mask']
        if mask is None:
            return

        bool_mask = torch.tensor(mask, dtype=torch.bool)

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(0))

        # Count the kept channels on the boolean mask that is also used for
        # indexing, so `out_channels` always equals the new weight.size(0).
        node_module.out_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          node_module.out_channels))
Exemplo n.º 9
0
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        """
        Prune the weight of an elementwise module according to the first
        input mask stored on the node.

        Does nothing when the mask is None, or when the containing module is
        not an instance of one of the keys of NNCF_WRAPPED_USER_MODULES_DICT.
        Otherwise dimension 0 of the weight is filtered by the mask and
        `n_channels` is updated.

        :param model: network that owns the module to prune.
        :param node: graph node; `node.data['input_masks'][0]` is the mask
            and `node.data['key']` is used in the log message.
        :param graph: not used by this method.
        """
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        node_module = model.get_containing_module(node.node_name)

        if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
            assert node_module.target_weight_dim_for_compression == 0, \
                "Implemented only for target_weight_dim_for_compression == 0"
            old_num_channels = int(node_module.weight.size(0))
            # Count the kept channels on the boolean mask that is also used
            # for indexing, so `n_channels` equals the new weight.size(0).
            new_num_channels = int(torch.sum(bool_mask))
            node_module.weight = torch.nn.Parameter(
                node_module.weight[bool_mask])
            node_module.n_channels = new_num_channels

            nncf_logger.info(
                'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
                ' {}.'.format(node.data['key'], old_num_channels,
                              new_num_channels))
Exemplo n.º 10
0
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        """
        Prune the weight and bias of a group norm module according to the
        first input mask stored on the node.

        Does nothing when the mask is None. Otherwise weight and bias are
        filtered by the mask, and both `num_channels` and `num_groups` are
        set to the number of kept entries.

        :param model: network that owns the module to prune.
        :param node: graph node; `node.data['input_masks'][0]` is the mask
            and `node.data['key']` is used in the log message.
        :param graph: not used by this method.
        """
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return

        node_module = model.get_containing_module(node.node_name)

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        old_num_channels = int(node_module.weight.size(0))
        # Count the kept channels on the boolean mask that is also used for
        # indexing, so the recorded sizes match the pruned parameters.
        new_num_channels = int(torch.sum(bool_mask))

        node_module.num_channels = new_num_channels
        # NOTE(review): setting num_groups to the channel count presumably
        # relies on the module having one group per channel - confirm.
        node_module.num_groups = new_num_channels

        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned GroupNorm {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          new_num_channels))
Exemplo n.º 11
0
    def __init__(self, algo_ctrl: QuantizationController, model: NNCFNetwork,
                       groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers,
                       add_flops=False):
        """
        Build `self._nx_graph`, a drawable copy of the model graph in which
        nodes are labelled and colored according to operator and quantizer
        information.

        Every node gets a label '<operator name>_#<node id>' and the 'filled'
        style. Nodes whose module is an NNCFConv2d additionally get the kernel
        size, a '_PAD' suffix when any padding value is non-zero and,
        optionally, FLOPs information. Activation quantizers are painted via
        `_paint_activation_quantizer_node`; weight quantizer nodes are
        relabelled and colored by their bitwidth.

        :param algo_ctrl: quantization controller whose
            `non_weight_quantizers` and `weight_quantizers` are drawn.
        :param model: network providing the graph, the modules and, when
            `add_flops` is set, the per-module FLOPs.
        :param groups_of_adjacent_quantizers: when truthy, weight quantizer
            labels get a '_G<group id>' suffix; it is also passed on to
            `_paint_activation_quantizer_node`.
        :param add_flops: when True, NNCFConv2d labels include the FLOPs of
            the module the node belongs to.
        :raises AttributeError: if a weight quantizer has no affected
            insertions.
        """
        # pylint:disable=too-many-branches
        # pylint:disable=too-many-statements
        nncf_graph = model.get_graph()
        self._nx_graph = nncf_graph.get_graph_for_structure_analysis()
        if add_flops:
            # One (flops, set of op nodes in the module's scope) pair per
            # module reported by the model.
            flops_and_node_groups = [
                (flops,
                 set(nncf_graph.get_op_nodes_in_scope(
                     nncf_graph.get_scope_by_node_name(module_node_name))))
                for module_node_name, flops in model.get_flops_per_module().items()
            ]

        grouped_mode = bool(groups_of_adjacent_quantizers)
        for node_key in nncf_graph.get_all_node_keys():
            node = nncf_graph.get_node_by_key(node_key)
            color = ''
            operator_name = node.node_type
            module = model.get_containing_module(node.node_name)
            if isinstance(module, NNCFConv2d):
                color = 'lightblue'
                if module.groups == module.in_channels and module.in_channels > 1:
                    operator_name = 'DW_Conv2d'
                    color = 'purple'
                kernel_size = 'x'.join(map(str, module.kernel_size))
                operator_name += f'_k{kernel_size}'
                # Padding is enabled as soon as one value is non-zero. Popping
                # an arbitrary element from a set of the values could pick a
                # zero for asymmetric padding such as (0, 1).
                padding_enabled = any(module.padding)
                if padding_enabled:
                    operator_name += '_PAD'
                if add_flops:
                    matches = [flops_and_nodes for flops_and_nodes in flops_and_node_groups
                               if node in flops_and_nodes[1]]
                    # Each convolution node must belong to exactly one module scope.
                    assert len(matches) == 1
                    flops, affected_nodes = matches[0]
                    operator_name += f'_FLOPS:{str(flops)}'
                    if len(affected_nodes) > 1:
                        node_ids = sorted([n.node_id for n in affected_nodes])
                        operator_name += "(shared among nodes {})".format(",".join(
                            [str(node_id) for node_id in node_ids]))
            operator_name += '_#{}'.format(node.node_id)
            target_node_to_draw = self._nx_graph.nodes[node_key]
            target_node_to_draw['label'] = operator_name
            target_node_to_draw['style'] = 'filled'
            if color:
                target_node_to_draw['color'] = color

        non_weight_quantizers = algo_ctrl.non_weight_quantizers
        bitwidth_color_map = {2: 'purple', 4: 'red', 8: 'green', 6: 'orange'}
        for quantizer_id, quantizer_info in non_weight_quantizers.items():
            self._paint_activation_quantizer_node(nncf_graph, quantizer_id,
                                                  quantizer_info, bitwidth_color_map,
                                                  groups_of_adjacent_quantizers)
        for wq_id, wq_info in algo_ctrl.weight_quantizers.items():
            nodes = [nncf_graph.get_node_by_name(tp.target_node_name)
                     for tp in wq_info.affected_insertions]
            if not nodes:
                raise AttributeError('Failed to get affected nodes for quantized module node: {}'.format(
                    wq_id.target_node_name))
            # The node to draw upon is the single predecessor of the affected
            # nodes whose name contains 'UpdateWeight'.
            preds = [nncf_graph.get_previous_nodes(node) for node in nodes]
            wq_nodes = []
            for pred_list in preds:
                for pred_node in pred_list:
                    if 'UpdateWeight' in pred_node.node_name:
                        wq_nodes.append(pred_node)
            assert len(wq_nodes) == 1

            node = wq_nodes[0]
            node_id = node.node_id
            key = nncf_graph.get_node_key_by_id(node_id)
            nx_node_to_draw_upon = self._nx_graph.nodes[key]
            quantizer = wq_info.quantizer_module_ref
            bitwidths = quantizer.num_bits
            nx_node_to_draw_upon['label'] = 'WFQ_[{}]_#{}'.format(quantizer.get_quantizer_config(), str(node_id))
            if grouped_mode:
                group_id_str = 'UNDEFINED'
                group_id = groups_of_adjacent_quantizers.get_group_id_for_quantizer(wq_id)
                if group_id is None:
                    nncf_logger.error('No group for weight quantizer for: {}'.format(wq_id))
                else:
                    group_id_str = str(group_id)
                nx_node_to_draw_upon['label'] += '_G' + group_id_str
            # Only the bitwidths listed in the map can be drawn; any other
            # value raises KeyError here.
            nx_node_to_draw_upon['color'] = bitwidth_color_map[bitwidths]
            nx_node_to_draw_upon['style'] = 'filled'
Exemplo n.º 12
0
    def _prune_weights(self, target_model: NNCFNetwork):
        """
        Build the insertion commands that attach weight-and-bias pruning
        hooks to the model.

        First, every node of every pruning group produced by
        `self.pruning_node_selector` gets a pruning operation; the groups are
        stored in `self.pruned_module_groups_info`. Then output masks are
        initialized and propagated through the graph, and every group norm
        node (plus batch norm nodes when `self.prune_batch_norms` is set)
        that ended up with an output mask gets a pruning operation too; those
        are recorded in `self._pruned_norms_operators`.

        :param target_model: network to attach the pruning operations to.
        :return: list of insertion commands for all created hooks.
        """
        original_graph = target_model.get_original_graph()
        node_groups = self.pruning_node_selector.create_pruning_groups(
            original_graph)

        device = next(target_model.parameters()).device
        insertion_commands = []

        def attach_pruner(module, node_name):
            # Create the pruning operation for `module`, queue the command
            # that inserts its weight-and-bias hook, and return the operation.
            pruning_block = self.create_weight_pruning_operation(
                module, node_name)
            hook = UpdateWeightAndBias(pruning_block).to(device)
            target_point = PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                         target_node_name=node_name)
            insertion_commands.append(
                PTInsertionCommand(target_point, hook,
                                   TransformationPriority.PRUNING_PRIORITY))
            return pruning_block

        self.pruned_module_groups_info = Clusterization[PrunedModuleInfo](
            lambda x: x.node_name)

        for cluster_id, group in enumerate(node_groups.get_all_clusters()):
            minfos_in_group = []
            for group_node in group.elements:
                name = group_node.node_name
                pruned_module = target_model.get_containing_module(name)
                scope = original_graph.get_scope_by_node_name(name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_module(pruned_module)

                nncf_logger.info(
                    "Adding Weight Pruner in scope: {}".format(name))
                operand = attach_pruner(pruned_module, name)
                minfos_in_group.append(
                    PrunedModuleInfo(
                        node_name=name,
                        module_scope=scope,
                        module=pruned_module,
                        operand=operand,
                        node_id=group_node.node_id,
                        is_depthwise=is_prunable_depthwise_conv(group_node)))

            group_node_ids = [n.node_id for n in group.elements]
            self.pruned_module_groups_info.add_cluster(
                Cluster[PrunedModuleInfo](cluster_id, minfos_in_group,
                                          group_node_ids))

        # Propagate masks to find norm layers to prune
        init_output_masks_in_graph(
            original_graph, self.pruned_module_groups_info.get_all_nodes())
        mask_propagator = MaskPropagationAlgorithm(
            original_graph, PT_PRUNING_OPERATOR_METATYPES,
            PTNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # Adding binary masks also for Batch/Group Norms to allow applying masks after propagation
        norm_types = ['group_norm']
        if self.prune_batch_norms:
            norm_types.append('batch_norm')

        for norm_node in original_graph.get_nodes_by_types(norm_types):
            # Only norm layers that received a mask are pruned.
            if norm_node.data['output_mask'] is not None:
                name = norm_node.node_name
                norm_module = target_model.get_containing_module(name)
                operand = attach_pruner(norm_module, name)
                self._pruned_norms_operators.append(
                    (norm_node, operand, norm_module))
        return insertion_commands