def check_correct_nncf_modules_replacement(model: NNCFNetwork, compressed_model: NNCFNetwork) \
        -> Tuple[Dict[Scope, Module], Dict[Scope, Module]]:
    """
    Check that all convolutions in the model were replaced by NNCF convolutions.
    :param model: original model
    :param compressed_model: compressed model
    :return: dict of all convolutions in the original model and dict of all NNCF convolutions
    from the compressed model
    """
    NNCF_MODULES_REVERSED_MAP = {value: key for key, value in NNCF_MODULES_MAP.items()}
    original_modules = get_all_modules_by_type(model, list(NNCF_MODULES_MAP.values()))
    nncf_modules = get_all_modules_by_type(compressed_model.get_nncf_wrapped_model(),
                                           list(NNCF_MODULES_MAP.keys()))
    assert len(original_modules) == len(nncf_modules)
    for scope in original_modules.keys():
        sparse_scope = deepcopy(scope)
        elt = sparse_scope.pop()  # type: ScopeElement
        elt.calling_module_class_name = NNCF_MODULES_REVERSED_MAP[elt.calling_module_class_name]
        sparse_scope.push(elt)
        assert sparse_scope in nncf_modules
    return original_modules, nncf_modules
def _prune_weights(self, target_model: NNCFNetwork):
    groups_of_modules_to_prune = self._create_pruning_groups(target_model)

    device = next(target_model.parameters()).device
    insertion_commands = []
    self.pruned_module_groups_info = Clusterization('module_scope')

    for i, group in enumerate(groups_of_modules_to_prune.get_all_clusters()):
        group_minfos = []
        for node in group.nodes:
            module_scope, module = node.module_scope, node.module
            # Check that we need to prune weights in this op
            assert self._is_pruned_module(module)

            module_scope_str = str(module_scope)
            nncf_logger.info("Adding Weight Pruner in scope: {}".format(module_scope_str))
            operation = self.create_weight_pruning_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP, module_scope=module_scope),
                    hook,
                    OperationPriority.PRUNING_PRIORITY
                )
            )

            related_modules = {}
            if self.prune_batch_norms:
                bn_module, bn_scope = get_bn_for_module_scope(target_model, module_scope)
                related_modules[PrunedModuleInfo.BN_MODULE_NAME] = BatchNormInfo(module_scope,
                                                                                 bn_module, bn_scope)

            minfo = PrunedModuleInfo(module_scope, module, hook.operand, related_modules, node.id)
            group_minfos.append(minfo)
        cluster = NodesCluster(i, group_minfos, [n.id for n in group.nodes])
        self.pruned_module_groups_info.add_cluster(cluster)
    return insertion_commands
def test_disable_shape_matching():
    class MatMulModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy_param = torch.nn.Parameter(torch.ones([1]))

        def forward(self, inputs):
            half1, half2 = torch.chunk(inputs, 2, dim=2)
            return torch.bmm(half1, half2.transpose(1, 2))

    model = MatMulModel()

    input_shape_1 = (3, 32, 32)
    input_shape_2 = (4, 64, 64)

    qnet_no_shape = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo(input_shape_1), ],
                                scopes_without_shape_matching=['MatMulModel'])  # type: NNCFNetwork
    _ = qnet_no_shape(torch.zeros(*input_shape_1))
    graph_1 = deepcopy(qnet_no_shape.get_graph())

    _ = qnet_no_shape(torch.zeros(*input_shape_2))
    graph_2 = deepcopy(qnet_no_shape.get_graph())

    keys_1 = list(graph_1.get_all_node_keys())
    keys_2 = list(graph_2.get_all_node_keys())
    assert len(keys_1) == 2  # 1 input node + 1 operation node
    assert keys_1 == keys_2

    qnet = NNCFNetwork(model, input_infos=[ModelInputInfo(input_shape_1), ])  # type: NNCFNetwork
    _ = qnet(torch.zeros(*input_shape_1))
    _ = qnet(torch.zeros(*input_shape_2))
    # The second forward run should have led to an increase in registered node counts
    # since disable_shape_matching was False and the network was run with a different
    # shape of input tensor
    assert qnet.get_graph().get_nodes_count() > graph_1.get_nodes_count()
def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    mask = nx_node['output_mask']
    if mask is None:
        return

    bool_mask = torch.tensor(mask, dtype=torch.bool)

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    old_num_channels = int(node_module.weight.size(0))
    node_module.out_channels = int(torch.sum(mask))
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

    if node_module.bias is not None and not nx_node['is_depthwise']:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

    nncf_logger.info('Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
                     ' {}.'.format(nx_node['key'], old_num_channels, node_module.out_channels))
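# A minimal, self-contained sketch (illustrative shapes and values only) of the mask application
# performed in output_prune above: a boolean mask over output channels indexes dimension 0 of the
# convolution weight tensor, keeping only the channels marked with 1.
import torch

weight = torch.randn(4, 3, 3, 3)           # (out_channels, in_channels, kH, kW)
mask = torch.tensor([1, 0, 1, 1])           # 1 = keep the channel, 0 = prune it
bool_mask = mask.to(torch.bool)
pruned_weight = weight[bool_mask]           # indexes dim 0 -> 3 output channels remain
assert pruned_weight.shape == (3, 3, 3, 3)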
def input_prune(self, model: NNCFNetwork, nx_node: dict, graph: NNCFGraph, nx_graph: nx.DiGraph):
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return
    bool_mask = torch.tensor(input_mask, dtype=torch.bool)

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
        assert node_module.target_weight_dim_for_compression == 0, \
            "Implemented only for target_weight_dim_for_compression == 0"
        old_num_channels = int(node_module.weight.size(0))
        new_num_channels = int(torch.sum(input_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.n_channels = new_num_channels

        nncf_logger.info('Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
                         ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def test_find_node_in_nx_graph_by_scope():
    model = TwoConvTestModel()
    nncf_model = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork
    nncf_graph = nncf_model.get_original_graph()

    # Valid scopes should be successfully found
    valid_nncf_modules = nncf_model.get_nncf_modules()
    nodes_list = list(nncf_graph._nx_graph.nodes)
    for module_scope, _ in valid_nncf_modules.items():
        graph_node = nncf_graph.find_node_in_nx_graph_by_scope(module_scope)
        assert graph_node is not None
        assert isinstance(graph_node, dict)
        assert graph_node['key'] in nodes_list

    fake_model = BasicConvTestModel()
    fake_nncf_model = NNCFNetwork(deepcopy(fake_model), input_infos=[ModelInputInfo([1, 1, 4, 4])])

    # Not valid scopes shouldn't be found
    fake_nncf_modules = fake_nncf_model.get_nncf_modules()
    for module_scope, _ in fake_nncf_modules.items():
        graph_node = nncf_graph.find_node_in_nx_graph_by_scope(module_scope)
        assert graph_node is None
def find_next_nodes_not_of_types(model: NNCFNetwork, nncf_node: NNCFNode, types: List[str]) -> List[NNCFNode]:
    """
    Traverse the graph starting from nncf_node to find the first nodes that are not of a type
    from the types list. "First nodes satisfying some condition" here means nodes:
    - for which this condition is true
    - reachable from nncf_node such that on the path from nncf_node to them there are no
      other nodes with the condition fulfilled
    :param model: model to work with
    :param nncf_node: NNCFNode to start the search from
    :param types: list of types
    :return: list of next nodes for nncf_node whose type is not in the types list
    """
    graph = model.get_original_graph()
    visited = {node_id: False for node_id in graph.get_all_node_idxs()}
    partial_traverse_function = partial(traverse_function, nncf_graph=graph,
                                        type_check_fn=lambda x: x not in types,
                                        visited=visited)
    nncf_nodes = [nncf_node]
    if nncf_node.op_exec_context.operator_name not in types:
        nncf_nodes = graph.get_next_nodes(nncf_node)

    next_nodes = []
    for node in nncf_nodes:
        next_nodes.extend(graph.traverse_graph(node, partial_traverse_function))
    return next_nodes
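# A conceptual, self-contained sketch of the traversal contract described in the docstring of
# find_next_nodes_not_of_types: starting from a node, return the first reachable nodes whose type
# is NOT in `types`, stopping each path at the first such node. It uses a plain adjacency dict
# instead of the NNCF graph; all node names and types below are illustrative only.
def _find_first_nodes_not_of_types(adjacency, node_types, start, types):
    result, stack, seen = [], list(adjacency.get(start, [])), set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node_types[node] not in types:
            result.append(node)                    # condition met: do not traverse past this node
        else:
            stack.extend(adjacency.get(node, []))
    return result

# conv -> bn -> relu -> conv2; skipping over batch norm and relu yields conv2
_adjacency = {'conv': ['bn'], 'bn': ['relu'], 'relu': ['conv2'], 'conv2': []}
_node_types = {'conv': 'conv2d', 'bn': 'batch_norm', 'relu': 'relu', 'conv2': 'conv2d'}
assert _find_first_nodes_not_of_types(_adjacency, _node_types, 'conv', ['batch_norm', 'relu']) == ['conv2']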
def apply_to(self, target_model: NNCFNetwork) -> NNCFNetwork:
    insertion_commands = self._prune_weights(target_model)
    for command in insertion_commands:
        target_model.register_insertion_command(command)
    target_model.register_algorithm(self)
    return target_model
def test_operator_metatype_marking(self):
    from nncf.dynamic_graph.operator_metatypes import Conv2dMetatype, BatchNormMetatype, RELUMetatype, \
        MaxPool2dMetatype, \
        ConvTranspose2dMetatype, DepthwiseConv2dSubtype, AddMetatype, AvgPool2dMetatype, LinearMetatype
    ref_scope_vs_metatype_dict = {
        "/" + MODEL_INPUT_OP_NAME + "_0": NoopMetatype,
        "ModelForMetatypeTesting/NNCFConv2d[conv_regular]/conv2d_0": Conv2dMetatype,
        "ModelForMetatypeTesting/BatchNorm2d[bn]/batch_norm_0": BatchNormMetatype,
        "ModelForMetatypeTesting/RELU_0": RELUMetatype,
        "ModelForMetatypeTesting/MaxPool2d[max_pool2d]/max_pool2d_0": MaxPool2dMetatype,
        "ModelForMetatypeTesting/NNCFConvTranspose2d[conv_transpose]/conv_transpose2d_0": ConvTranspose2dMetatype,
        "ModelForMetatypeTesting/NNCFConv2d[conv_depthwise]/conv2d_0": DepthwiseConv2dSubtype,
        "ModelForMetatypeTesting/__iadd___0": AddMetatype,
        "ModelForMetatypeTesting/AdaptiveAvgPool2d[adaptive_avg_pool]/adaptive_avg_pool2d_0": AvgPool2dMetatype,
        "ModelForMetatypeTesting/NNCFLinear[linear]/linear_0": LinearMetatype
    }

    class ModelForMetatypeTesting(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv_regular = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
            self.bn = torch.nn.BatchNorm2d(num_features=16)
            self.max_pool2d = torch.nn.MaxPool2d(kernel_size=2)
            self.conv_transpose = torch.nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3)
            self.conv_depthwise = torch.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5, groups=8)
            self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=1)
            self.linear = torch.nn.Linear(in_features=8, out_features=1)

        def forward(self, input_):
            x = self.conv_regular(input_)
            x = self.bn(x)
            x = torch.nn.functional.relu(x)
            x.transpose_(2, 3)
            x = self.max_pool2d(x)
            x = self.conv_transpose(x)
            x = self.conv_depthwise(x)
            x += torch.ones_like(x)
            x = self.adaptive_avg_pool(x)
            x = self.linear(x.flatten())
            return x

    model = ModelForMetatypeTesting()
    nncf_network = NNCFNetwork(model, [ModelInputInfo([1, 3, 300, 300])])
    ip_graph = nncf_network.get_insertion_point_graph()

    for node in ip_graph.nodes().values():
        if node[InsertionPointGraph.NODE_TYPE_NODE_ATTR] == InsertionPointGraphNodeType.OPERATOR:
            nncf_node_ref = node[InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR]
            scope_str = str(nncf_node_ref[NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].input_agnostic)
            assert scope_str in ref_scope_vs_metatype_dict
            ref_metatype = ref_scope_vs_metatype_dict[scope_str]
            assert node[InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR] == ref_metatype
def setup(self):
    self.compressed_model = NNCFNetwork(InsertionPointTestModel(),
                                        [ModelInputInfo([1, 1, 10, 10])])  # type: NNCFNetwork
class TestInsertionCommands:
    @pytest.fixture()
    def setup(self):
        self.compressed_model = NNCFNetwork(InsertionPointTestModel(),
                                            [ModelInputInfo([1, 1, 10, 10])])  # type: NNCFNetwork

    conv1_module_scope = Scope.from_str('InsertionPointTestModel/NNCFConv2d[conv1]')
    conv1_module_context = InputAgnosticOperationExecutionContext('', conv1_module_scope, 0)
    point_for_conv1_weights = InsertionPoint(ia_op_exec_context=conv1_module_context,
                                             insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv1_inputs = InsertionPoint(ia_op_exec_context=conv1_module_context,
                                            insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv1_activations = InsertionPoint(ia_op_exec_context=conv1_module_context,
                                                 insertion_type=InsertionType.NNCF_MODULE_POST_OP)

    conv2_module_scope = Scope.from_str('InsertionPointTestModel/NNCFConv2d[conv2]')
    conv2_module_context = InputAgnosticOperationExecutionContext('', conv2_module_scope, 0)
    point_for_conv2_weights = InsertionPoint(ia_op_exec_context=conv2_module_context,
                                             insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv2_inputs = InsertionPoint(ia_op_exec_context=conv2_module_context,
                                            insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv2_activations = InsertionPoint(ia_op_exec_context=conv2_module_context,
                                                 insertion_type=InsertionType.NNCF_MODULE_POST_OP)

    linear_op_scope = Scope.from_str('InsertionPointTestModel/linear_0')
    linear_op_context = InputAgnosticOperationExecutionContext('linear', linear_op_scope, 0)
    point_for_linear_weight_input = InsertionPoint(ia_op_exec_context=linear_op_context,
                                                   insertion_type=InsertionType.OPERATOR_PRE_HOOK)
    point_for_linear_activation = InsertionPoint(ia_op_exec_context=linear_op_context,
                                                 insertion_type=InsertionType.OPERATOR_POST_HOOK)

    relu_op_scope = Scope.from_str('InsertionPointTestModel/ReLU[relu]/relu')
    relu_op_context = InputAgnosticOperationExecutionContext('relu', relu_op_scope, 0)
    point_for_relu_inputs = InsertionPoint(ia_op_exec_context=relu_op_context,
                                           insertion_type=InsertionType.OPERATOR_PRE_HOOK)
    point_for_relu_activations = InsertionPoint(ia_op_exec_context=relu_op_context,
                                                insertion_type=InsertionType.OPERATOR_POST_HOOK)

    available_points = [point_for_conv1_weights,
                        point_for_conv2_weights,
                        point_for_conv1_inputs,
                        point_for_conv2_inputs,
                        point_for_conv1_activations,
                        point_for_conv2_activations,
                        point_for_linear_activation,
                        point_for_linear_weight_input,
                        point_for_relu_activations,
                        point_for_relu_inputs]

    @pytest.mark.parametrize("insertion_point", available_points)
    def test_single_insertions(self, setup, insertion_point):
        if insertion_point.insertion_type in [InsertionType.OPERATOR_PRE_HOOK, InsertionType.OPERATOR_POST_HOOK]:
            hook = lambda x: x
        else:
            hook = BaseOp(lambda x: x)

        command = InsertionCommand(insertion_point, hook)
        self.compressed_model.register_insertion_command(command)
        self.compressed_model.commit_compression_changes()

        # pylint:disable=protected-access
        if insertion_point.insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            assert ctx._pre_hooks[command.insertion_point.ia_op_exec_context][0] is hook
        if insertion_point.insertion_type == InsertionType.OPERATOR_POST_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            assert ctx._post_hooks[command.insertion_point.ia_op_exec_context][0] is hook
        if insertion_point.insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            module = self.compressed_model.get_module_by_scope(
                command.insertion_point.ia_op_exec_context.scope_in_model)
            assert module.pre_ops["0"] is hook
        if insertion_point.insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            module = self.compressed_model.get_module_by_scope(
                command.insertion_point.ia_op_exec_context.scope_in_model)
            assert module.post_ops["0"] is hook

    priority_types = ["same", "different"]
    insertion_types = InsertionType
    priority_test_cases = list(itertools.product(priority_types, insertion_types))

    @staticmethod
    def check_order(iterable1: List, iterable2: List, ordering: List):
        for idx, order in enumerate(ordering):
            assert iterable1[idx] is iterable2[order]

    # pylint:disable=undefined-variable
    @pytest.mark.parametrize("case", priority_test_cases,
                             ids=[x[1].name + '-' + x[0] for x in priority_test_cases])
    def test_priority(self, case, setup):
        # pylint:disable=too-many-branches
        priority_type = case[0]
        insertion_type = case[1]
        if insertion_type in [InsertionType.NNCF_MODULE_PRE_OP, InsertionType.NNCF_MODULE_POST_OP]:
            hook1 = BaseOp(lambda x: x)
            hook2 = BaseOp(lambda x: 2 * x)
            hook3 = BaseOp(lambda x: 3 * x)
        else:
            hook1 = lambda x: x
            hook2 = lambda x: 2 * x
            hook3 = lambda x: 3 * x

        if insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            point = self.point_for_conv2_weights
        elif insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            point = self.point_for_conv1_activations
        elif insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            point = self.point_for_linear_weight_input
        elif insertion_type == InsertionType.OPERATOR_POST_HOOK:
            point = self.point_for_relu_activations

        if priority_type == "same":
            # Same-priority commands will be executed in registration order
            command1 = InsertionCommand(point, hook1, OperationPriority.DEFAULT_PRIORITY)
            command2 = InsertionCommand(point, hook2, OperationPriority.DEFAULT_PRIORITY)
            command3 = InsertionCommand(point, hook3, OperationPriority.DEFAULT_PRIORITY)
        else:
            # Prioritized commands will be executed in ascending priority order
            command1 = InsertionCommand(point, hook1, OperationPriority.SPARSIFICATION_PRIORITY)
            command2 = InsertionCommand(point, hook2, OperationPriority.QUANTIZATION_PRIORITY)
            command3 = InsertionCommand(point, hook3, OperationPriority.DEFAULT_PRIORITY)

        self.compressed_model.register_insertion_command(command1)
        self.compressed_model.register_insertion_command(command2)
        self.compressed_model.register_insertion_command(command3)
        self.compressed_model.commit_compression_changes()

        hook_list = [hook1, hook2, hook3]

        if priority_type == "same":
            order = [0, 1, 2]
        elif priority_type == "different":
            order = [2, 0, 1]

        # pylint:disable=protected-access
        if insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            self.check_order(ctx._pre_hooks[point.ia_op_exec_context], hook_list, order)
        if insertion_type == InsertionType.OPERATOR_POST_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            self.check_order(ctx._post_hooks[point.ia_op_exec_context], hook_list, order)
        if insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            module = self.compressed_model.get_module_by_scope(point.ia_op_exec_context.scope_in_model)
            # Works because Pytorch ModuleDict is ordered
            self.check_order(list(module.pre_ops.values()), hook_list, order)
        if insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            module = self.compressed_model.get_module_by_scope(point.ia_op_exec_context.scope_in_model)
            # Works because Pytorch ModuleDict is ordered
            self.check_order(list(module.post_ops.values()), hook_list, order)
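# The "different"-priority branch of test_priority above relies on commands being applied in
# ascending priority order, which produces the expected ordering [2, 0, 1]. A minimal,
# self-contained sketch of that ordering follows; the numeric priority values are illustrative
# (only their relative order matters here), and the hook names mirror the ones used in the test.
from enum import IntEnum

class Priority(IntEnum):
    DEFAULT = 0
    SPARSIFICATION = 2
    QUANTIZATION = 11

registered = [('hook1', Priority.SPARSIFICATION),
              ('hook2', Priority.QUANTIZATION),
              ('hook3', Priority.DEFAULT)]
applied = [name for name, _ in sorted(registered, key=lambda cmd: cmd[1])]
assert applied == ['hook3', 'hook1', 'hook2']   # i.e. order == [2, 0, 1] relative to hook_list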
def create_compressed_model(model: Module, config: NNCFConfig,
                            resuming_state_dict: dict = None,
                            dummy_forward_fn: Callable[[Module], Any] = None,
                            wrap_inputs_fn: Callable[[Tuple, Dict], Tuple[Tuple, Dict]] = None,
                            dump_graphs=True,) \
        -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original
    PyTorch model and a configuration object.

    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model
    :param resuming_state_dict: A PyTorch state dict object to load (strictly) into the compressed model after
    building.
    :param dummy_forward_fn: if supplied, will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param wrap_inputs_fn: if supplied, will be used on the module's input arguments during a regular, non-dummy
    forward call before passing the inputs to the underlying compressed model. This is required if the model's
    input tensors that are important for compression are not supplied as arguments to the model's forward call
    directly, but instead are located in a container (such as list), and the model receives the container as an
    argument. wrap_inputs_fn should take as input two arguments - the tuple of positional arguments to the
    underlying model's forward call, and a dict of keyword arguments to the same. The function should wrap each
    tensor among the supplied model's args and kwargs that is important for compression (e.g. quantization) with
    an nncf.nncf_model_input function, which is a no-operation function and marks the tensors as inputs to be
    traced by NNCF in the internal graph representation. Output is the tuple of (args, kwargs), where args and
    kwargs are the same as were supplied in input, but each tensor in the original input that is important for
    compression is wrapped with an nncf.nncf_model_input call.
    :param dump_graphs: Whether or not should also dump the internal graph representation of the original
    and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression parameter training wrapped
    as an object of NNCFNetwork."""

    # Compress the model that will be deployed for inference on the target device. There is no need to compress
    # parts of the model that are used only during training (e.g. AuxLogits of the Inception-v3 model) or unused
    # modules with weights. As a consequence, there is no need to worry about spoiling BN statistics, as they are
    # disabled in eval mode.
    model.eval()

    if dump_graphs:
        if dummy_forward_fn is None:
            input_info_list = create_input_infos(config)
            graph_builder = GraphBuilder(custom_forward_fn=create_dummy_forward_fn(input_info_list,
                                                                                   with_input_tracing=True))
        else:
            graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        if is_main_process():
            graph = graph_builder.build_graph(model)
            graph.visualize_graph(osp.join(config.get("log_dir", "."), "original_graph.dot"))

    set_debug_log_dir(config.get("log_dir", "."))

    input_info_list = create_input_infos(config)
    scopes_without_shape_matching = config.get('scopes_without_shape_matching', [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(model, input_infos=input_info_list,
                                   dummy_forward_fn=dummy_forward_fn,
                                   wrap_inputs_fn=wrap_inputs_fn,
                                   ignored_scopes=ignored_scopes,
                                   target_scopes=target_scopes,
                                   scopes_without_shape_matching=scopes_without_shape_matching)

    should_init = resuming_state_dict is None
    compression_algo_builder_list = create_compression_algorithm_builders(config, should_init=should_init)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    try:
        if resuming_state_dict is not None:
            load_state(compressed_model, resuming_state_dict, is_resume=True)
    finally:
        if dump_graphs and is_main_process() and compression_algo_builder_list:
            if dummy_forward_fn is None:
                compressed_graph_builder = GraphBuilder(
                    custom_forward_fn=create_dummy_forward_fn(input_info_list, with_input_tracing=False))
            else:
                compressed_graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

            graph = compressed_graph_builder.build_graph(compressed_model,
                                                         compressed_model.get_tracing_context())
            graph.visualize_graph(osp.join(config.get("log_dir", "."), "compressed_graph.dot"))
    return compression_ctrl, compressed_model
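# A hedged usage sketch of the wrap_inputs_fn contract documented in create_compressed_model above.
# Assumptions: an NNCF version that exposes NNCFConfig, create_compressed_model and nncf_model_input
# at the package level (as the docstring suggests), and an "nncf_config.json" file whose input_info
# sample_size matches the toy model. The model class, config path, shapes and the accompanying
# dummy_forward_fn (which mirrors the list-style input during tracing) are illustrative only.
import torch
from nncf import NNCFConfig, create_compressed_model, nncf_model_input

class ListInputModel(torch.nn.Module):
    # Toy model whose forward takes its tensor inside a list, so the default input_info-based
    # wrapping cannot reach the tensor that matters for compression.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, tensor_list):
        return self.conv(tensor_list[0])

def wrap_inputs_fn(args, kwargs):
    # args[0] is the list holding the tensor important for compression; mark it as an NNCF input.
    args[0][0] = nncf_model_input(args[0][0])
    return args, kwargs

def dummy_forward_fn(model):
    # Mirrors the non-standard input format during graph tracing.
    return model([nncf_model_input(torch.randn(1, 1, 8, 8))])

nncf_config = NNCFConfig.from_json("nncf_config.json")  # illustrative path
compression_ctrl, compressed_model = create_compressed_model(ListInputModel(), nncf_config,
                                                             dummy_forward_fn=dummy_forward_fn,
                                                             wrap_inputs_fn=wrap_inputs_fn)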
def _prune_weights(self, target_model: NNCFNetwork):
    device = next(target_model.parameters()).device
    modules_to_prune = target_model.get_nncf_modules()
    insertion_commands = []

    input_non_pruned_modules = get_first_pruned_modules(target_model,
                                                        self.get_types_of_pruned_modules() + ['linear'])
    output_non_pruned_modules = get_last_pruned_modules(target_model,
                                                        self.get_types_of_pruned_modules() + ['linear'])

    for module_scope, module in modules_to_prune.items():
        # Check that we need to prune weights in this op
        if not self._is_pruned_module(module):
            continue

        module_scope_str = str(module_scope)
        if not self._should_consider_scope(module_scope_str):
            nncf_logger.info("Ignored adding Weight Pruner in scope: {}".format(module_scope_str))
            continue

        if not self.prune_first and module in input_non_pruned_modules:
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is one of the first convolutions".format(module_scope_str))
            continue
        if not self.prune_last and module in output_non_pruned_modules:
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is one of the last convolutions".format(module_scope_str))
            continue

        if not self.prune_downsample_convs and is_conv_with_downsampling(module):
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is convolution with downsample".format(module_scope_str))
            continue

        nncf_logger.info("Adding Weight Pruner in scope: {}".format(module_scope_str))
        operation = self.create_weight_pruning_operation(module)
        hook = UpdateWeight(operation).to(device)
        insertion_commands.append(
            InsertionCommand(
                InsertionPoint(InputAgnosticOperationExecutionContext("", module_scope, 0),
                               InsertionType.NNCF_MODULE_PRE_OP),
                hook,
                OperationPriority.PRUNING_PRIORITY
            )
        )

        related_modules = {}
        if self.prune_batch_norms:
            related_modules['bn_module'] = get_bn_for_module_scope(target_model, module_scope)

        self._pruned_module_info.append(
            PrunedModuleInfo(module_scope_str, module, hook.operand, related_modules))
    return insertion_commands
def check_model_graph(compressed_model: NNCFNetwork, ref_dot_file_name: str, ref_dot_file_directory: str):
    compressed_model.to('cuda')
    compressed_model.do_dummy_forward()
    compressed_model.do_dummy_forward()
    check_graph(compressed_model.get_graph(), ref_dot_file_name, ref_dot_file_directory)
def _create_pruning_groups(self, target_model: NNCFNetwork):
    """
    This function groups ALL modules with pruning types into groups that should be pruned together.
    1. Create clusters for special ops (e.g. eltwises) that should be pruned together
    2. Create groups of nodes that should be pruned together (taking into account clusters of special ops)
    3. Add remaining single nodes
    4. Unite clusters for Conv + Depthwise conv (they should be pruned together too)
    5. Check groups (either all nodes in a group can be pruned, or the whole group cannot be pruned)
    Return groups of modules that should be pruned together.
    :param target_model: model to work with
    :return: clusterization of pruned nodes
    """
    graph = target_model.get_original_graph()
    pruned_types = self.get_op_types_of_pruned_modules()
    all_modules_to_prune = target_model.get_nncf_modules_by_module_names(self.compressed_nncf_module_names)
    all_nodes_to_prune = graph.get_nodes_by_types(pruned_types)  # NNCFNodes here
    assert len(all_nodes_to_prune) <= len(all_modules_to_prune)

    # 1. Clusters for special ops
    special_ops_types = self.get_types_of_grouping_ops()
    identity_like_types = IdentityMaskForwardOps.get_all_op_aliases()
    special_ops_clusterization = cluster_special_ops(target_model, special_ops_types,
                                                     identity_like_types)

    pruned_nodes_clusterization = Clusterization("id")

    # 2. Clusters for nodes that should be pruned together (taking into account clusters for special ops)
    for i, cluster in enumerate(special_ops_clusterization.get_all_clusters()):
        all_pruned_inputs = []
        pruned_inputs_idxs = set()

        for node in cluster.nodes:
            sources = get_sources_of_node(node, graph, pruned_types)
            for source_node in sources:
                source_scope = source_node.op_exec_context.scope_in_model
                source_module = target_model.get_module_by_scope(source_scope)
                source_node_info = NodeInfo(source_node, source_module, source_scope)

                if source_node.node_id not in pruned_inputs_idxs:
                    all_pruned_inputs.append(source_node_info)
                    pruned_inputs_idxs.add(source_node.node_id)
        if all_pruned_inputs:
            cluster = NodesCluster(i, list(all_pruned_inputs), [n.id for n in all_pruned_inputs])
            pruned_nodes_clusterization.add_cluster(cluster)

    last_cluster_idx = len(special_ops_clusterization.get_all_clusters())

    # 3. Add remaining single nodes as separate clusters
    for node in all_nodes_to_prune:
        if not pruned_nodes_clusterization.is_node_in_clusterization(node.node_id):
            scope = node.op_exec_context.scope_in_model
            module = target_model.get_module_by_scope(scope)
            node_info = NodeInfo(node, module, scope)

            cluster = NodesCluster(last_cluster_idx, [node_info], [node.node_id])
            pruned_nodes_clusterization.add_cluster(cluster)

            last_cluster_idx += 1

    # 4. Merge clusters for Conv + Depthwise conv (they should be pruned together too)
    for node in all_nodes_to_prune:
        scope = node.op_exec_context.scope_in_model
        module = target_model.get_module_by_scope(scope)
        cluster_id = pruned_nodes_clusterization.get_cluster_by_node_id(node.node_id).id

        if is_depthwise_conv(module):
            previous_conv = get_previous_conv(target_model, module, scope)
            if previous_conv:
                previous_conv_cluster_id = pruned_nodes_clusterization.get_cluster_by_node_id(
                    previous_conv.node_id).id
                pruned_nodes_clusterization.merge_clusters(cluster_id, previous_conv_cluster_id)

    # 5. Check groups (either all nodes in a group can be pruned, or the whole group cannot be pruned).
    model_analyser = ModelAnalyzer(target_model)
    can_prune_analysis = model_analyser.analyse_model_before_pruning()
    self._check_pruning_groups(target_model, pruned_nodes_clusterization, can_prune_analysis)
    return pruned_nodes_clusterization
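# A conceptual sketch of step 4 in _create_pruning_groups above: the cluster of a depthwise
# convolution is merged with the cluster of the convolution that feeds it, so that both are pruned
# together. A plain dict-based clusterization stands in for the NNCF Clusterization class; node
# names are illustrative only.
clusters = {0: {'conv1'}, 1: {'depthwise_conv1'}, 2: {'conv2'}}
node_to_cluster = {'conv1': 0, 'depthwise_conv1': 1, 'conv2': 2}

def merge_clusters(dst_id, src_id):
    # Move every node of the source cluster into the destination cluster.
    for node in clusters.pop(src_id):
        clusters[dst_id].add(node)
        node_to_cluster[node] = dst_id

merge_clusters(node_to_cluster['conv1'], node_to_cluster['depthwise_conv1'])
assert clusters == {0: {'conv1', 'depthwise_conv1'}, 2: {'conv2'}}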
def _prune_weights(self, target_model: NNCFNetwork):
    device = next(target_model.parameters()).device
    modules_to_prune = target_model.get_nncf_modules()
    insertion_commands = []
    bn_for_depthwise = {}

    input_non_pruned_modules = get_first_pruned_modules(target_model,
                                                        self.get_types_of_pruned_modules() + ['linear'])
    output_non_pruned_modules = get_last_pruned_modules(target_model,
                                                        self.get_types_of_pruned_modules() + ['linear'])

    for module_scope, module in modules_to_prune.items():
        # Check that we need to prune weights in this op
        if not self._is_pruned_module(module):
            continue

        module_scope_str = str(module_scope)
        if self.ignore_frozen_layers and not module.weight.requires_grad:
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " the layer appears to be frozen (requires_grad=False)".format(module_scope_str))
            continue

        if not self._should_consider_scope(module_scope_str):
            nncf_logger.info("Ignored adding Weight Pruner in scope: {}".format(module_scope_str))
            continue

        if not self.prune_first and module in input_non_pruned_modules:
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is one of the first convolutions".format(module_scope_str))
            continue
        if not self.prune_last and module in output_non_pruned_modules:
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is one of the last convolutions".format(module_scope_str))
            continue

        if is_grouped_conv(module):
            if is_depthwise_conv(module):
                previous_conv = get_previous_conv(target_model, module, module_scope)
                if previous_conv:
                    depthwise_bn = get_bn_for_module_scope(target_model, module_scope)
                    bn_for_depthwise[str(previous_conv.op_exec_context.scope_in_model)] = depthwise_bn

            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is grouped convolution".format(module_scope_str))
            continue

        if not self.prune_downsample_convs and is_conv_with_downsampling(module):
            nncf_logger.info("Ignored adding Weight Pruner in scope: {} because"
                             " this scope is convolution with downsample".format(module_scope_str))
            continue

        nncf_logger.info("Adding Weight Pruner in scope: {}".format(module_scope_str))
        operation = self.create_weight_pruning_operation(module)
        hook = UpdateWeight(operation).to(device)
        insertion_commands.append(
            InsertionCommand(
                InsertionPoint(InputAgnosticOperationExecutionContext("", module_scope, 0),
                               InsertionType.NNCF_MODULE_PRE_OP),
                hook,
                OperationPriority.PRUNING_PRIORITY
            )
        )

        related_modules = {}
        if self.prune_batch_norms:
            related_modules[PrunedModuleInfo.BN_MODULE_NAME] = get_bn_for_module_scope(target_model,
                                                                                       module_scope)

        self._pruned_module_info.append(
            PrunedModuleInfo(module_scope_str, module, hook.operand, related_modules))

    if self.prune_batch_norms:
        self.update_minfo_with_depthwise_bn(bn_for_depthwise)

    return insertion_commands
def create_compressed_model(model: Module, config: NNCFConfig,
                            resuming_state_dict: dict = None,
                            dummy_forward_fn: Callable[[Module], Any] = None,
                            dump_graphs=True,) \
        -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original
    PyTorch model and a configuration object.

    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model
    :param resuming_state_dict: A PyTorch state dict object to load (strictly) into the compressed model after
    building.
    :param dummy_forward_fn: will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param dump_graphs: Whether or not should also dump the internal graph representation of the original
    and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression parameter training wrapped
    as an object of NNCFNetwork."""

    if dump_graphs:
        if dummy_forward_fn is None:
            input_info_list = create_input_infos(config)
            graph_builder = GraphBuilder(custom_forward_fn=create_dummy_forward_fn(input_info_list,
                                                                                   with_input_tracing=True))
        else:
            graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        if is_main_process():
            graph = graph_builder.build_graph(model)
            graph.dump_graph(osp.join(config.get("log_dir", "."), "original_graph.dot"), extended=True)

    if is_debug():
        set_debug_log_dir(config.get("log_dir", "."))

    input_info_list = create_input_infos(config)
    scopes_without_shape_matching = config.get('scopes_without_shape_matching', [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(model, input_infos=input_info_list,
                                   dummy_forward_fn=dummy_forward_fn,
                                   ignored_scopes=ignored_scopes,
                                   target_scopes=target_scopes,
                                   scopes_without_shape_matching=scopes_without_shape_matching)

    should_init = resuming_state_dict is None
    compression_algo_builder_list = create_compression_algorithm_builders(config, should_init=should_init)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    if dump_graphs and is_main_process() and compression_algo_builder_list:
        if dummy_forward_fn is None:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(input_info_list, with_input_tracing=False))
        else:
            compressed_graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        graph = compressed_graph_builder.build_graph(compressed_model,
                                                     compressed_model.get_tracing_context())
        graph.dump_graph(osp.join(config.get("log_dir", "."), "compressed_graph.dot"), extended=True)

    if resuming_state_dict is not None:
        load_state(compressed_model, resuming_state_dict, is_resume=True)

    return compression_ctrl, compressed_model
def test_check_correct_modules_replacement():
    model = TwoConvTestModel()
    nncf_model = NNCFNetwork(TwoConvTestModel(), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork

    _, nncf_modules = check_correct_nncf_modules_replacement(model, nncf_model)
    assert set(nncf_modules) == set(nncf_model.get_nncf_modules())