Example #1
def check_correct_nncf_modules_replacement(model: NNCFNetwork, compressed_model: NNCFNetwork) \
    -> Tuple[Dict[Scope, Module], Dict[Scope, Module]]:
    """
    Check that all convolutions in the model were replaced by NNCF convolutions.
    :param model: original model
    :param compressed_model: compressed model
    :return: dict of all convolutions in the original model and dict of all NNCF convolutions from the compressed model
    """
    NNCF_MODULES_REVERSED_MAP = {value: key for key, value in NNCF_MODULES_MAP.items()}
    original_modules = get_all_modules_by_type(model, list(NNCF_MODULES_MAP.values()))
    nncf_modules = get_all_modules_by_type(compressed_model.get_nncf_wrapped_model(),
                                           list(NNCF_MODULES_MAP.keys()))
    assert len(original_modules) == len(nncf_modules)
    print(original_modules, nncf_modules)
    for scope in original_modules.keys():
        sparse_scope = deepcopy(scope)
        elt = sparse_scope.pop()  # type: ScopeElement
        elt.calling_module_class_name = NNCF_MODULES_REVERSED_MAP[elt.calling_module_class_name]
        sparse_scope.push(elt)
        print(sparse_scope, nncf_modules)
        assert sparse_scope in nncf_modules
    return original_modules, nncf_modules
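A minimal usage sketch for the helper above (hypothetical; it assumes the TwoConvTestModel and ModelInputInfo classes used elsewhere in this listing are importable, and mirrors the call in Example #18):

# Hypothetical usage sketch, assuming TwoConvTestModel and ModelInputInfo are importable
model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo([1, 1, 4, 4])])
original_modules, nncf_modules = check_correct_nncf_modules_replacement(model, nncf_model)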
Example #2
    def _prune_weights(self, target_model: NNCFNetwork):
        groups_of_modules_to_prune = self._create_pruning_groups(target_model)

        device = next(target_model.parameters()).device
        insertion_commands = []
        self.pruned_module_groups_info = Clusterization('module_scope')

        for i, group in enumerate(groups_of_modules_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.nodes:
                module_scope, module = node.module_scope, node.module
                # Check that we need to prune weights in this op
                assert self._is_pruned_module(module)

                module_scope_str = str(module_scope)

                nncf_logger.info("Adding Weight Pruner in scope: {}".format(module_scope_str))
                operation = self.create_weight_pruning_operation(module)
                hook = UpdateWeight(operation).to(device)
                insertion_commands.append(
                    InsertionCommand(
                        InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP,
                                       module_scope=module_scope),
                        hook,
                        OperationPriority.PRUNING_PRIORITY
                    )
                )

                related_modules = {}
                if self.prune_batch_norms:
                    bn_module, bn_scope = get_bn_for_module_scope(target_model, module_scope)
                    related_modules[PrunedModuleInfo.BN_MODULE_NAME] = BatchNormInfo(module_scope,
                                                                                     bn_module, bn_scope)

                minfo = PrunedModuleInfo(module_scope, module, hook.operand, related_modules, node.id)
                group_minfos.append(minfo)
            cluster = NodesCluster(i, group_minfos, [n.id for n in group.nodes])
            self.pruned_module_groups_info.add_cluster(cluster)
        return insertion_commands
Example #3
def test_disable_shape_matching():
    class MatMulModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy_param = torch.nn.Parameter(torch.ones([1]))

        def forward(self, inputs):
            half1, half2 = torch.chunk(inputs, 2, dim=2)
            return torch.bmm(half1, half2.transpose(1, 2))

    model = MatMulModel()

    input_shape_1 = (3, 32, 32)
    input_shape_2 = (4, 64, 64)

    qnet_no_shape = NNCFNetwork(
        deepcopy(model),
        input_infos=[
            ModelInputInfo(input_shape_1),
        ],
        scopes_without_shape_matching=['MatMulModel'])  # type: NNCFNetwork
    _ = qnet_no_shape(torch.zeros(*input_shape_1))
    graph_1 = deepcopy(qnet_no_shape.get_graph())

    _ = qnet_no_shape(torch.zeros(*input_shape_2))
    graph_2 = deepcopy(qnet_no_shape.get_graph())

    keys_1 = list(graph_1.get_all_node_keys())
    keys_2 = list(graph_2.get_all_node_keys())
    assert len(keys_1) == 2  # 1 input node + 1 operation node
    assert keys_1 == keys_2

    qnet = NNCFNetwork(model, input_infos=[
        ModelInputInfo(input_shape_1),
    ])  # type: NNCFNetwork
    _ = qnet(torch.zeros(*input_shape_1))
    _ = qnet(torch.zeros(*input_shape_2))
    # The second forward run should have led to an increase in registered node counts
    # since disable_shape_matching was False and the network was run with a different
    # shape of input tensor
    assert qnet.get_graph().get_nodes_count() > graph_1.get_nodes_count()
Example #4
    def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                     nx_graph: nx.DiGraph):
        mask = nx_node['output_mask']
        if mask is None:
            return

        bool_mask = torch.tensor(mask, dtype=torch.bool)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(0))

        node_module.out_channels = int(torch.sum(mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        if node_module.bias is not None and not nx_node['is_depthwise']:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.out_channels))
Example #5
    def input_prune(self, model: NNCFNetwork, nx_node: dict, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
            assert node_module.target_weight_dim_for_compression == 0,\
                "Implemented only for target_weight_dim_for_compression == 0"
            old_num_channels = int(node_module.weight.size(0))
            new_num_channels = int(torch.sum(input_mask))
            node_module.weight = torch.nn.Parameter(
                node_module.weight[bool_mask])
            node_module.n_channels = new_num_channels

            nncf_logger.info(
                'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
                ' {}.'.format(nx_node['key'], old_num_channels,
                              new_num_channels))
Example #6
def test_find_node_in_nx_graph_by_scope():
    model = TwoConvTestModel()
    nncf_model = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork
    nncf_graph = nncf_model.get_original_graph()

    # Valid scopes should be successfully found
    valid_nncf_modules = nncf_model.get_nncf_modules()
    nodes_list = list(nncf_graph._nx_graph.nodes)
    for module_scope, _ in valid_nncf_modules.items():
        graph_node = nncf_graph.find_node_in_nx_graph_by_scope(module_scope)
        assert graph_node is not None
        assert isinstance(graph_node, dict)
        assert graph_node['key'] in nodes_list

    fake_model = BasicConvTestModel()
    fake_nncf_model = NNCFNetwork(deepcopy(fake_model), input_infos=[ModelInputInfo([1, 1, 4, 4])])

    # Invalid scopes should not be found
    fake_nncf_modules = fake_nncf_model.get_nncf_modules()
    for module_scope, _ in fake_nncf_modules.items():
        graph_node = nncf_graph.find_node_in_nx_graph_by_scope(module_scope)
        assert graph_node is None
Example #7
File: utils.py Project: xiaming9880/nncf
def find_next_nodes_not_of_types(model: NNCFNetwork, nncf_node: NNCFNode, types: List[str]) -> List[NNCFNode]:
    """
    Traverse the graph from nncf_node to find the first nodes whose type is not in the types list.
    Here the "first" nodes satisfying some condition are the nodes that:
    - fulfil this condition
    - are reachable from nncf_node such that the path from nncf_node to them contains no other nodes with the
    fulfilled condition
    :param model: model to work with
    :param nncf_node: NNCFNode to start the search from
    :param types: list of types
    :return: list of the next nodes after nncf_node whose type is not in the types list
    """
    graph = model.get_original_graph()
    visited = {node_id: False for node_id in graph.get_all_node_idxs()}
    partial_traverse_function = partial(traverse_function, nncf_graph=graph, type_check_fn=lambda x: x not in types,
                                        visited=visited)
    nncf_nodes = [nncf_node]
    if nncf_node.op_exec_context.operator_name not in types:
        nncf_nodes = graph.get_next_nodes(nncf_node)

    next_nodes = []
    for node in nncf_nodes:
        next_nodes.extend(graph.traverse_graph(node, partial_traverse_function))
    return next_nodes
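A hypothetical usage sketch for the traversal helper above (the op type aliases 'conv2d', 'batch_norm' and 'relu' are assumed for illustration and may differ from the actual NNCF aliases):

# Hypothetical usage: find the nearest successors of the first convolution node
# that are neither batch norms nor ReLUs (op aliases assumed for illustration only).
graph = model.get_original_graph()
first_conv_node = graph.get_nodes_by_types(['conv2d'])[0]
next_nodes = find_next_nodes_not_of_types(model, first_conv_node, ['batch_norm', 'relu'])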
Example #8
    def apply_to(self, target_model: NNCFNetwork) -> NNCFNetwork:
        insertion_commands = self._prune_weights(target_model)
        for command in insertion_commands:
            target_model.register_insertion_command(command)
        target_model.register_algorithm(self)
        return target_model
Example #9
    def test_operator_metatype_marking(self):
        from nncf.dynamic_graph.operator_metatypes import Conv2dMetatype, BatchNormMetatype, RELUMetatype, \
            MaxPool2dMetatype, ConvTranspose2dMetatype, DepthwiseConv2dSubtype, AddMetatype, AvgPool2dMetatype, \
            LinearMetatype, NoopMetatype
        ref_scope_vs_metatype_dict = {
            "/" + MODEL_INPUT_OP_NAME + "_0":
            NoopMetatype,
            "ModelForMetatypeTesting/NNCFConv2d[conv_regular]/conv2d_0":
            Conv2dMetatype,
            "ModelForMetatypeTesting/BatchNorm2d[bn]/batch_norm_0":
            BatchNormMetatype,
            "ModelForMetatypeTesting/RELU_0":
            RELUMetatype,
            "ModelForMetatypeTesting/MaxPool2d[max_pool2d]/max_pool2d_0":
            MaxPool2dMetatype,
            "ModelForMetatypeTesting/NNCFConvTranspose2d[conv_transpose]/conv_transpose2d_0":
            ConvTranspose2dMetatype,
            "ModelForMetatypeTesting/NNCFConv2d[conv_depthwise]/conv2d_0":
            DepthwiseConv2dSubtype,
            "ModelForMetatypeTesting/__iadd___0":
            AddMetatype,
            "ModelForMetatypeTesting/AdaptiveAvgPool2d[adaptive_avg_pool]/adaptive_avg_pool2d_0":
            AvgPool2dMetatype,
            "ModelForMetatypeTesting/NNCFLinear[linear]/linear_0":
            LinearMetatype
        }

        class ModelForMetatypeTesting(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_regular = torch.nn.Conv2d(in_channels=3,
                                                    out_channels=16,
                                                    kernel_size=3)
                self.bn = torch.nn.BatchNorm2d(num_features=16)
                self.max_pool2d = torch.nn.MaxPool2d(kernel_size=2)
                self.conv_transpose = torch.nn.ConvTranspose2d(in_channels=16,
                                                               out_channels=8,
                                                               kernel_size=3)
                self.conv_depthwise = torch.nn.Conv2d(in_channels=8,
                                                      out_channels=8,
                                                      kernel_size=5,
                                                      groups=8)
                self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(
                    output_size=1)
                self.linear = torch.nn.Linear(in_features=8, out_features=1)

            def forward(self, input_):
                x = self.conv_regular(input_)
                x = self.bn(x)
                x = torch.nn.functional.relu(x)
                x.transpose_(2, 3)
                x = self.max_pool2d(x)
                x = self.conv_transpose(x)
                x = self.conv_depthwise(x)
                x += torch.ones_like(x)
                x = self.adaptive_avg_pool(x)
                x = self.linear(x.flatten())
                return x

        model = ModelForMetatypeTesting()
        nncf_network = NNCFNetwork(model, [ModelInputInfo([1, 3, 300, 300])])
        ip_graph = nncf_network.get_insertion_point_graph()

        for node in ip_graph.nodes().values():
            if node[InsertionPointGraph.
                    NODE_TYPE_NODE_ATTR] == InsertionPointGraphNodeType.OPERATOR:
                nncf_node_ref = node[
                    InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR]
                scope_str = str(nncf_node_ref[
                    NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].input_agnostic)
                assert scope_str in ref_scope_vs_metatype_dict
                ref_metatype = ref_scope_vs_metatype_dict[scope_str]
                assert node[InsertionPointGraph.
                            OPERATOR_METATYPE_NODE_ATTR] == ref_metatype
Example #10
    def setup(self):
        self.compressed_model = NNCFNetwork(
            InsertionPointTestModel(),
            [ModelInputInfo([1, 1, 10, 10])])  # type: NNCFNetwork
Example #11
class TestInsertionCommands:
    @pytest.fixture()
    def setup(self):
        self.compressed_model = NNCFNetwork(
            InsertionPointTestModel(),
            [ModelInputInfo([1, 1, 10, 10])])  # type: NNCFNetwork

    conv1_module_scope = Scope.from_str(
        'InsertionPointTestModel/NNCFConv2d[conv1]')
    conv1_module_context = InputAgnosticOperationExecutionContext(
        '', conv1_module_scope, 0)
    point_for_conv1_weights = InsertionPoint(
        ia_op_exec_context=conv1_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv1_inputs = InsertionPoint(
        ia_op_exec_context=conv1_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv1_activations = InsertionPoint(
        ia_op_exec_context=conv1_module_context,
        insertion_type=InsertionType.NNCF_MODULE_POST_OP)

    conv2_module_scope = Scope.from_str(
        'InsertionPointTestModel/NNCFConv2d[conv2]')
    conv2_module_context = InputAgnosticOperationExecutionContext(
        '', conv2_module_scope, 0)
    point_for_conv2_weights = InsertionPoint(
        ia_op_exec_context=conv2_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv2_inputs = InsertionPoint(
        ia_op_exec_context=conv2_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv2_activations = InsertionPoint(
        ia_op_exec_context=conv2_module_context,
        insertion_type=InsertionType.NNCF_MODULE_POST_OP)

    linear_op_scope = Scope.from_str('InsertionPointTestModel/linear_0')
    linear_op_context = InputAgnosticOperationExecutionContext(
        'linear', linear_op_scope, 0)
    point_for_linear_weight_input = InsertionPoint(
        ia_op_exec_context=linear_op_context,
        insertion_type=InsertionType.OPERATOR_PRE_HOOK)
    point_for_linear_activation = InsertionPoint(
        ia_op_exec_context=linear_op_context,
        insertion_type=InsertionType.OPERATOR_POST_HOOK)

    relu_op_scope = Scope.from_str('InsertionPointTestModel/ReLU[relu]/relu')
    relu_op_context = InputAgnosticOperationExecutionContext(
        'relu', relu_op_scope, 0)
    point_for_relu_inputs = InsertionPoint(
        ia_op_exec_context=relu_op_context,
        insertion_type=InsertionType.OPERATOR_PRE_HOOK)
    point_for_relu_activations = InsertionPoint(
        ia_op_exec_context=relu_op_context,
        insertion_type=InsertionType.OPERATOR_POST_HOOK)

    available_points = [
        point_for_conv1_weights, point_for_conv2_weights,
        point_for_conv1_inputs, point_for_conv2_inputs,
        point_for_conv1_activations, point_for_conv2_activations,
        point_for_linear_activation, point_for_linear_weight_input,
        point_for_relu_activations, point_for_relu_inputs
    ]

    @pytest.mark.parametrize("insertion_point", available_points)
    def test_single_insertions(self, setup, insertion_point):
        if insertion_point.insertion_type in [
                InsertionType.OPERATOR_PRE_HOOK,
                InsertionType.OPERATOR_POST_HOOK
        ]:
            hook = lambda x: x
        else:
            hook = BaseOp(lambda x: x)

        command = InsertionCommand(insertion_point, hook)
        self.compressed_model.register_insertion_command(command)
        self.compressed_model.commit_compression_changes()

        #pylint:disable=protected-access
        if insertion_point.insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            assert ctx._pre_hooks[
                command.insertion_point.ia_op_exec_context][0] is hook
        if insertion_point.insertion_type == InsertionType.OPERATOR_POST_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            assert ctx._post_hooks[
                command.insertion_point.ia_op_exec_context][0] is hook
        if insertion_point.insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            module = self.compressed_model.get_module_by_scope(
                command.insertion_point.ia_op_exec_context.scope_in_model)
            assert module.pre_ops["0"] is hook

        if insertion_point.insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            module = self.compressed_model.get_module_by_scope(
                command.insertion_point.ia_op_exec_context.scope_in_model)
            assert module.post_ops["0"] is hook

    priority_types = ["same", "different"]
    insertion_types = InsertionType
    priority_test_cases = list(
        itertools.product(priority_types, insertion_types))

    @staticmethod
    def check_order(iterable1: List, iterable2: List, ordering: List):
        for idx, order in enumerate(ordering):
            assert iterable1[idx] is iterable2[order]

    # pylint:disable=undefined-variable
    @pytest.mark.parametrize(
        "case",
        priority_test_cases,
        ids=[x[1].name + '-' + x[0] for x in priority_test_cases])
    def test_priority(self, case, setup):
        #pylint:disable=too-many-branches
        priority_type = case[0]
        insertion_type = case[1]
        if insertion_type in [
                InsertionType.NNCF_MODULE_PRE_OP,
                InsertionType.NNCF_MODULE_POST_OP
        ]:
            hook1 = BaseOp(lambda x: x)
            hook2 = BaseOp(lambda x: 2 * x)
            hook3 = BaseOp(lambda x: 3 * x)
        else:
            hook1 = lambda x: x
            hook2 = lambda x: 2 * x
            hook3 = lambda x: 3 * x

        if insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            point = self.point_for_conv2_weights
        elif insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            point = self.point_for_conv1_activations
        elif insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            point = self.point_for_linear_weight_input
        elif insertion_type == InsertionType.OPERATOR_POST_HOOK:
            point = self.point_for_relu_activations

        if priority_type == "same":
            # Same-priority commands will be executed in registration order
            command1 = InsertionCommand(point, hook1,
                                        OperationPriority.DEFAULT_PRIORITY)
            command2 = InsertionCommand(point, hook2,
                                        OperationPriority.DEFAULT_PRIORITY)
            command3 = InsertionCommand(point, hook3,
                                        OperationPriority.DEFAULT_PRIORITY)
        else:
            # Prioritized commands will be executed in ascending priority order
            command1 = InsertionCommand(
                point, hook1, OperationPriority.SPARSIFICATION_PRIORITY)
            command2 = InsertionCommand(
                point, hook2, OperationPriority.QUANTIZATION_PRIORITY)
            command3 = InsertionCommand(point, hook3,
                                        OperationPriority.DEFAULT_PRIORITY)

        self.compressed_model.register_insertion_command(command1)
        self.compressed_model.register_insertion_command(command2)
        self.compressed_model.register_insertion_command(command3)
        self.compressed_model.commit_compression_changes()

        hook_list = [hook1, hook2, hook3]

        if priority_type == "same":
            order = [0, 1, 2]
        elif priority_type == "different":
            order = [2, 0, 1]

        #pylint:disable=protected-access
        if insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            self.check_order(ctx._pre_hooks[point.ia_op_exec_context],
                             hook_list, order)
        if insertion_type == InsertionType.OPERATOR_POST_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            self.check_order(ctx._post_hooks[point.ia_op_exec_context],
                             hook_list, order)

        if insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            module = self.compressed_model.get_module_by_scope(
                point.ia_op_exec_context.scope_in_model)
            # Works because PyTorch's ModuleDict preserves insertion order
            self.check_order(list(module.pre_ops.values()), hook_list, order)

        if insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            module = self.compressed_model.get_module_by_scope(
                point.ia_op_exec_context.scope_in_model)
            # Works because PyTorch's ModuleDict preserves insertion order
            self.check_order(list(module.post_ops.values()), hook_list, order)
Example #12
def create_compressed_model(model: Module, config: NNCFConfig,
                            resuming_state_dict: dict = None,
                            dummy_forward_fn: Callable[[Module], Any] = None,
                            wrap_inputs_fn: Callable[[Tuple, Dict], Tuple[Tuple, Dict]] = None,
                            dump_graphs=True,) \
    -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original PyTorch
    model and a configuration object.
    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model
    :param resuming_state_dict: A PyTorch state dict object to load (strictly) into the compressed model after
    building.
    :param dummy_forward_fn: if supplied, will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param wrap_inputs_fn: if supplied, will be used on the module's input arguments during a regular, non-dummy
    forward call before passing the inputs to the underlying compressed model. This is required if the model's input
    tensors that are important for compression are not supplied as arguments to the model's forward call directly, but
    instead are located in a container (such as list), and the model receives the container as an argument.
    wrap_inputs_fn should take as input two arguments - the tuple of positional arguments to the underlying
    model's forward call, and a dict of keyword arguments to the same. The function should wrap each tensor among the
    supplied model's args and kwargs that is important for compression (e.g. quantization) with an nncf.nncf_model_input
    function, which is a no-operation function and marks the tensors as inputs to be traced by NNCF in the internal
    graph representation. The output is the tuple of (args, kwargs), where args and kwargs are the same as were
    supplied on input, but with each such tensor in the original input wrapped by nncf.nncf_model_input.
    :param dump_graphs: Whether to also dump the internal graph representations of the
    original and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression parameter training wrapped
    as an object of NNCFNetwork."""

    # Compress the model that will be deployed for inference on the target device. There is no need to compress parts
    # of the model that are used only at the training stage (e.g. the AuxLogits of the Inception-v3 model) or unused
    # modules with weights. As a consequence, there is no need to worry about spoiling BN statistics, since BN updates
    # are disabled in eval mode.
    model.eval()

    if dump_graphs:
        if dummy_forward_fn is None:
            input_info_list = create_input_infos(config)
            graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=True))
        else:
            graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        if is_main_process():
            graph = graph_builder.build_graph(model)
            graph.visualize_graph(
                osp.join(config.get("log_dir", "."), "original_graph.dot"))

    set_debug_log_dir(config.get("log_dir", "."))

    input_info_list = create_input_infos(config)
    scopes_without_shape_matching = config.get('scopes_without_shape_matching',
                                               [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(
        model,
        input_infos=input_info_list,
        dummy_forward_fn=dummy_forward_fn,
        wrap_inputs_fn=wrap_inputs_fn,
        ignored_scopes=ignored_scopes,
        target_scopes=target_scopes,
        scopes_without_shape_matching=scopes_without_shape_matching)

    should_init = resuming_state_dict is None
    compression_algo_builder_list = create_compression_algorithm_builders(
        config, should_init=should_init)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    try:
        if resuming_state_dict is not None:
            load_state(compressed_model, resuming_state_dict, is_resume=True)
    finally:
        if dump_graphs and is_main_process() and compression_algo_builder_list:
            if dummy_forward_fn is None:
                compressed_graph_builder = GraphBuilder(
                    custom_forward_fn=create_dummy_forward_fn(
                        input_info_list, with_input_tracing=False))
            else:
                compressed_graph_builder = GraphBuilder(
                    custom_forward_fn=dummy_forward_fn)

            graph = compressed_graph_builder.build_graph(
                compressed_model, compressed_model.get_tracing_context())
            graph.visualize_graph(
                osp.join(config.get("log_dir", "."), "compressed_graph.dot"))
    return compression_ctrl, compressed_model
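A minimal sketch of a wrap_inputs_fn matching the contract described in the docstring above. It assumes a hypothetical model whose forward receives a dict batch with an 'image' tensor; the import location of nncf_model_input may differ between NNCF versions:

# Hypothetical wrap_inputs_fn: wraps the tensor that matters for compression with
# nncf_model_input so that NNCF traces it as a model input.
from nncf import nncf_model_input

def my_wrap_inputs_fn(args, kwargs):
    batch = args[0]  # assumption: forward(batch) where batch is a dict with an 'image' tensor
    batch['image'] = nncf_model_input(batch['image'])
    return args, kwargs

# compression_ctrl, compressed_model = create_compressed_model(model, config,
#                                                               wrap_inputs_fn=my_wrap_inputs_fn)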
Example #13
    def _prune_weights(self, target_model: NNCFNetwork):
        device = next(target_model.parameters()).device
        modules_to_prune = target_model.get_nncf_modules()
        insertion_commands = []

        input_non_pruned_modules = get_first_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
        output_non_pruned_modules = get_last_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])

        for module_scope, module in modules_to_prune.items():
            # Check that we need to prune weights in this op
            if not self._is_pruned_module(module):
                continue

            module_scope_str = str(module_scope)
            if not self._should_consider_scope(module_scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {}".format(
                        module_scope_str))
                continue

            if not self.prune_first and module in input_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the first convolutions".format(
                        module_scope_str))
                continue
            if not self.prune_last and module in output_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the last convolutions".format(
                        module_scope_str))
                continue

            if not self.prune_downsample_convs and is_conv_with_downsampling(
                    module):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is convolution with downsample".format(
                        module_scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Pruner in scope: {}".format(module_scope_str))
            operation = self.create_weight_pruning_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext(
                            "", module_scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), hook,
                    OperationPriority.PRUNING_PRIORITY))

            related_modules = {}
            if self.prune_batch_norms:
                related_modules['bn_module'] = get_bn_for_module_scope(
                    target_model, module_scope)

            self._pruned_module_info.append(
                PrunedModuleInfo(module_scope_str, module, hook.operand,
                                 related_modules))

        return insertion_commands
Example #14
def check_model_graph(compressed_model: NNCFNetwork, ref_dot_file_name: str, ref_dot_file_directory: str):
    compressed_model.to('cuda')
    compressed_model.do_dummy_forward()
    compressed_model.do_dummy_forward()
    check_graph(compressed_model.get_graph(), ref_dot_file_name, ref_dot_file_directory)
Example #15
    def _create_pruning_groups(self, target_model: NNCFNetwork):
        """
        This function groups ALL modules with pruning types into groups that should be pruned together.
        1. Create clusters of special ops (eltwises) that should be pruned together
        2. Create groups of nodes that should be pruned together (taking into account clusters of special ops)
        3. Add remaining single nodes
        4. Unite clusters for Conv + Depthwise conv (they should be pruned together too)
        5. Check the groups (either all nodes in a group can be pruned or the whole group can't be pruned)
        Returns groups of modules that should be pruned together.
        :param target_model: model to work with
        :return: clusterization of pruned nodes
        """
        graph = target_model.get_original_graph()
        pruned_types = self.get_op_types_of_pruned_modules()
        all_modules_to_prune = target_model.get_nncf_modules_by_module_names(self.compressed_nncf_module_names)
        all_nodes_to_prune = graph.get_nodes_by_types(pruned_types)  # NNCFNodes here
        assert len(all_nodes_to_prune) <= len(all_modules_to_prune)

        # 1. Clusters for special ops
        special_ops_types = self.get_types_of_grouping_ops()
        identity_like_types = IdentityMaskForwardOps.get_all_op_aliases()
        special_ops_clusterization = cluster_special_ops(target_model, special_ops_types,
                                                         identity_like_types)

        pruned_nodes_clusterization = Clusterization("id")

        # 2. Clusters for nodes that should be pruned together (taking into account clusters for special ops)
        for i, cluster in enumerate(special_ops_clusterization.get_all_clusters()):
            all_pruned_inputs = []
            pruned_inputs_idxs = set()

            for node in cluster.nodes:
                sources = get_sources_of_node(node, graph, pruned_types)
                for source_node in sources:
                    source_scope = source_node.op_exec_context.scope_in_model
                    source_module = target_model.get_module_by_scope(source_scope)
                    source_node_info = NodeInfo(source_node, source_module, source_scope)

                    if source_node.node_id not in pruned_inputs_idxs:
                        all_pruned_inputs.append(source_node_info)
                        pruned_inputs_idxs.add(source_node.node_id)
            if all_pruned_inputs:
                cluster = NodesCluster(i, list(all_pruned_inputs), [n.id for n in all_pruned_inputs])
                pruned_nodes_clusterization.add_cluster(cluster)

        last_cluster_idx = len(special_ops_clusterization.get_all_clusters())

        # 3. Add remaining single nodes as separate clusters
        for node in all_nodes_to_prune:
            if not pruned_nodes_clusterization.is_node_in_clusterization(node.node_id):
                scope = node.op_exec_context.scope_in_model
                module = target_model.get_module_by_scope(scope)
                node_info = NodeInfo(node, module, scope)

                cluster = NodesCluster(last_cluster_idx, [node_info], [node.node_id])
                pruned_nodes_clusterization.add_cluster(cluster)

                last_cluster_idx += 1

        # 4. Merge clusters for Conv + Depthwise conv (should be pruned together too)
        for node in all_nodes_to_prune:
            scope = node.op_exec_context.scope_in_model
            module = target_model.get_module_by_scope(scope)
            cluster_id = pruned_nodes_clusterization.get_cluster_by_node_id(node.node_id).id

            if is_depthwise_conv(module):
                previous_conv = get_previous_conv(target_model, module, scope)
                if previous_conv:
                    previous_conv_cluster_id = pruned_nodes_clusterization.get_cluster_by_node_id(
                        previous_conv.node_id).id
                    pruned_nodes_clusterization.merge_clusters(cluster_id, previous_conv_cluster_id)

        # 5. Check the groups (either all nodes in a group can be pruned or the whole group can't be pruned).
        model_analyser = ModelAnalyzer(target_model)
        can_prune_analysis = model_analyser.analyse_model_before_pruning()
        self._check_pruning_groups(target_model, pruned_nodes_clusterization, can_prune_analysis)
        return pruned_nodes_clusterization
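A short sketch of how the returned clusterization is typically consumed, mirroring the iteration pattern of _prune_weights in Example #2 (attribute names are taken from that example; sketch only, it belongs inside the same builder class):

#     groups = self._create_pruning_groups(target_model)
#     for group in groups.get_all_clusters():
#         for node in group.nodes:
#             nncf_logger.info("Group {} contains module in scope {}".format(group.id, str(node.module_scope)))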
Example #16
    def _prune_weights(self, target_model: NNCFNetwork):
        device = next(target_model.parameters()).device
        modules_to_prune = target_model.get_nncf_modules()
        insertion_commands = []
        bn_for_depthwise = {}

        input_non_pruned_modules = get_first_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
        output_non_pruned_modules = get_last_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])

        for module_scope, module in modules_to_prune.items():
            # Check that we need to prune weights in this op
            if not self._is_pruned_module(module):
                continue

            module_scope_str = str(module_scope)
            if self.ignore_frozen_layers and not module.weight.requires_grad:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " the layer appears to be frozen (requires_grad=False)".
                    format(module_scope_str))
                continue

            if not self._should_consider_scope(module_scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {}".format(
                        module_scope_str))
                continue

            if not self.prune_first and module in input_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the first convolutions".format(
                        module_scope_str))
                continue
            if not self.prune_last and module in output_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the last convolutions".format(
                        module_scope_str))
                continue

            if is_grouped_conv(module):
                if is_depthwise_conv(module):
                    previous_conv = get_previous_conv(target_model, module,
                                                      module_scope)
                    if previous_conv:
                        depthwise_bn = get_bn_for_module_scope(
                            target_model, module_scope)
                        bn_for_depthwise[str(previous_conv.op_exec_context.
                                             scope_in_model)] = depthwise_bn

                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is grouped convolution".format(
                        module_scope_str))
                continue

            if not self.prune_downsample_convs and is_conv_with_downsampling(
                    module):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is convolution with downsample".format(
                        module_scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Pruner in scope: {}".format(module_scope_str))
            operation = self.create_weight_pruning_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext(
                            "", module_scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), hook,
                    OperationPriority.PRUNING_PRIORITY))

            related_modules = {}
            if self.prune_batch_norms:
                related_modules[
                    PrunedModuleInfo.BN_MODULE_NAME] = get_bn_for_module_scope(
                        target_model, module_scope)

            self._pruned_module_info.append(
                PrunedModuleInfo(module_scope_str, module, hook.operand,
                                 related_modules))

        if self.prune_batch_norms:
            self.update_minfo_with_depthwise_bn(bn_for_depthwise)

        return insertion_commands
Example #17
def create_compressed_model(model: Module, config: NNCFConfig,
                            resuming_state_dict: dict = None,
                            dummy_forward_fn: Callable[[Module], Any] = None,
                            dump_graphs=True,) \
    -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original PyTorch
    model and a configuration object.
    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model
    :param resuming_state_dict: A PyTorch state dict object to load (strictly) into the compressed model after
    building.
    :param dummy_forward_fn: if supplied, will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param dump_graphs: Whether to also dump the internal graph representations of the
    original and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression parameter training wrapped
    as an object of NNCFNetwork."""

    if dump_graphs:
        if dummy_forward_fn is None:
            input_info_list = create_input_infos(config)
            graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=True))
        else:
            graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        if is_main_process():
            graph = graph_builder.build_graph(model)
            graph.dump_graph(osp.join(config.get("log_dir", "."),
                                      "original_graph.dot"),
                             extended=True)

    if is_debug():
        set_debug_log_dir(config.get("log_dir", "."))

    input_info_list = create_input_infos(config)
    scopes_without_shape_matching = config.get('scopes_without_shape_matching',
                                               [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(
        model,
        input_infos=input_info_list,
        dummy_forward_fn=dummy_forward_fn,
        ignored_scopes=ignored_scopes,
        target_scopes=target_scopes,
        scopes_without_shape_matching=scopes_without_shape_matching)

    should_init = resuming_state_dict is None
    compression_algo_builder_list = create_compression_algorithm_builders(
        config, should_init=should_init)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    if dump_graphs and is_main_process() and compression_algo_builder_list:
        if dummy_forward_fn is None:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=False))
        else:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=dummy_forward_fn)

        graph = compressed_graph_builder.build_graph(
            compressed_model, compressed_model.get_tracing_context())
        graph.dump_graph(osp.join(config.get("log_dir", "."),
                                  "compressed_graph.dot"),
                         extended=True)

    if resuming_state_dict is not None:
        load_state(compressed_model, resuming_state_dict, is_resume=True)

    return compression_ctrl, compressed_model
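A minimal sketch of a custom dummy_forward_fn for the parameter described above (purely illustrative: the input shape and the extra keyword argument are assumptions about a hypothetical training pipeline):

import torch

# Hypothetical dummy forward: builds a mock tensor itself instead of relying on the
# input shapes from the config, and passes an extra argument this model expects.
def my_dummy_forward_fn(model):
    mock_input = torch.ones([1, 3, 224, 224])
    return model(mock_input, apply_postprocessing=False)

# compression_ctrl, compressed_model = create_compressed_model(model, config,
#                                                               dummy_forward_fn=my_dummy_forward_fn)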
Example #18
def test_check_correct_modules_replacement():
    model = TwoConvTestModel()
    nncf_model = NNCFNetwork(TwoConvTestModel(), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork

    _, nncf_modules = check_correct_nncf_modules_replacement(model, nncf_model)
    assert set(nncf_modules) == set(nncf_model.get_nncf_modules())