Example #1
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        old_num_channels = int(node_module.weight.size(0))
        new_num_channels = int(torch.sum(input_mask))

        node_module.num_features = new_num_channels
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])
        node_module.running_mean = torch.nn.Parameter(
            node_module.running_mean[bool_mask], requires_grad=False)
        node_module.running_var = torch.nn.Parameter(
            node_module.running_var[bool_mask], requires_grad=False)

        nncf_logger.info(
            'Pruned BatchNorm {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
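The heavy lifting above is plain boolean-mask indexing of the BatchNorm parameters and running statistics. A minimal standalone sketch of the same idea (the module and mask below are illustrative, not taken from NNCF):

    import torch

    bn = torch.nn.BatchNorm2d(4)
    mask = torch.tensor([True, False, True, True])

    bn.num_features = int(mask.sum())
    bn.weight = torch.nn.Parameter(bn.weight[mask])
    bn.bias = torch.nn.Parameter(bn.bias[mask])
    bn.running_mean = bn.running_mean[mask]  # buffers can be reassigned directly
    bn.running_var = bn.running_var[mask]

    print(bn(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 8, 8])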
Example #2
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(input_mask))

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        is_depthwise = nx_node['is_depthwise']
        old_num_channels = int(node_module.weight.size(1))

        if is_depthwise:
            # In the depthwise case the output channels are pruned by the input mask;
            # here we only update the number of input channels and groups
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
        else:
            out_channels = node_module.weight.size(0)
            broadcasted_mask = bool_mask.repeat(out_channels).view(
                out_channels, bool_mask.size(0))
            new_weight_shape = list(node_module.weight.shape)
            new_weight_shape[1] = new_num_channels

            node_module.in_channels = new_num_channels
            node_module.weight = torch.nn.Parameter(
                node_module.weight[broadcasted_mask].view(new_weight_shape))

        nncf_logger.info(
            'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
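The repeat/view pattern above selects the surviving input channels from the 4-D convolution weight. A small self-contained sketch of that trick (names are illustrative, not part of NNCF):

    import torch

    conv = torch.nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3)
    mask = torch.tensor([True, False, True, False])

    out_channels = conv.weight.size(0)
    # Boolean mask over the first two weight dimensions: (out_channels, in_channels)
    broadcasted_mask = mask.repeat(out_channels).view(out_channels, mask.size(0))
    new_weight_shape = list(conv.weight.shape)
    new_weight_shape[1] = int(mask.sum())

    conv.in_channels = int(mask.sum())
    conv.weight = torch.nn.Parameter(conv.weight[broadcasted_mask].view(new_weight_shape))

    print(conv(torch.randn(1, 2, 8, 8)).shape)  # torch.Size([1, 6, 6, 6])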
Example #3
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        output_mask = None
        accept_pruned_input = True
        is_depthwise = False
        input_masks = get_input_masks(nx_node, nx_graph)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        if node_module.pre_ops:
            output_mask = node_module.pre_ops[
                '0'].op.binary_filter_pruning_mask

        # In case of group convs we can't prune by output filters
        if node_module.groups != 1:
            if node_module.weight.size(1) == 1:
                # Depthwise case
                is_depthwise = True
                output_mask = input_masks[0]
            else:
                accept_pruned_input = False
                output_mask = None

        nx_node['input_masks'] = input_masks
        nx_node['output_mask'] = output_mask
        nx_node['accept_pruned_input'] = accept_pruned_input
        nx_node['is_depthwise'] = is_depthwise
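The groups/weight checks above distinguish depthwise convolutions (which can be pruned by the incoming mask) from other grouped convolutions (which this operation skips). A hedged sketch of that check on plain torch modules, outside of NNCF:

    import torch

    depthwise = torch.nn.Conv2d(8, 8, kernel_size=3, groups=8)
    grouped = torch.nn.Conv2d(8, 8, kernel_size=3, groups=2)

    # Depthwise: grouped, with a single input channel per group in the weight
    print(depthwise.groups != 1, depthwise.weight.size(1) == 1)  # True True
    # Other grouped convs keep several input channels per group
    print(grouped.groups != 1, grouped.weight.size(1) == 1)      # True False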
Example #4
    def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                     nx_graph: nx.DiGraph):
        output_mask = nx_node['output_mask']
        if output_mask is None:
            return

        bool_mask = torch.tensor(output_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(bool_mask))

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(1))

        in_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(in_channels).view(
            in_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.out_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(
            node_module.weight[broadcasted_mask].view(new_weight_shape))

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.out_channels))
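Note that for ConvTranspose modules the output channels sit in dimension 1 of the weight, which is why the mask above is broadcast along the second axis rather than the first. A quick illustrative check:

    import torch

    deconv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=5, kernel_size=2)
    print(deconv.weight.shape)  # torch.Size([3, 5, 2, 2])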
Example #5
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        output_mask = None
        accept_pruned_input = True
        input_masks = get_input_masks(nx_node, nx_graph)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        if node_module.pre_ops:
            output_mask = node_module.pre_ops[
                '0'].op.binary_filter_pruning_mask

        nx_node['input_masks'] = input_masks
        nx_node['output_mask'] = output_mask
        nx_node['accept_pruned_input'] = accept_pruned_input
Example #6
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(0))

        node_module.in_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.in_channels))
Example #7
    def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                     nx_graph: nx.DiGraph):
        mask = nx_node['output_mask']
        if mask is None:
            return

        bool_mask = torch.tensor(mask, dtype=torch.bool)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(0))

        node_module.out_channels = int(torch.sum(mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        if node_module.bias is not None and not nx_node['is_depthwise']:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.out_channels))
Example #8
    def input_prune(self, model: NNCFNetwork, nx_node: dict, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
            assert node_module.target_weight_dim_for_compression == 0,\
                "Implemented only for target_weight_dim_for_compression == 0"
            old_num_channels = int(node_module.weight.size(0))
            new_num_channels = int(torch.sum(input_mask))
            node_module.weight = torch.nn.Parameter(
                node_module.weight[bool_mask])
            node_module.n_channels = new_num_channels

            nncf_logger.info(
                'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
                ' {}.'.format(nx_node['key'], old_num_channels,
                              new_num_channels))