def _is_module_prunable(self, graph: NNCFGraph, node: NNCFNode) -> PruningAnalysisDecision:
    """
    Check whether the module corresponding to the provided node should be pruned
    according to the algorithm parameters.

    :param graph: Graph to work with.
    :param node: Node to check.
    :return: Pruning analysis decision.
    """
    stop_propagation_ops = self._stop_propagation_op_metatype.get_all_op_aliases()
    types_to_track = self._prune_operations_types + stop_propagation_ops
    input_non_pruned_nodes = get_first_nodes_of_type(graph, types_to_track)
    node_name = node.node_name

    if not should_consider_scope(node_name, self._ignored_scopes, self._target_scopes):
        return PruningAnalysisDecision(False, PruningAnalysisReason.IGNORED_SCOPE)

    if not self._prune_first and node in input_non_pruned_nodes:
        return PruningAnalysisDecision(False, PruningAnalysisReason.FIRST_CONV)

    if is_grouped_conv(node) and not is_prunable_depthwise_conv(node):
        return PruningAnalysisDecision(False, PruningAnalysisReason.GROUP_CONV)

    if not self._prune_downsample_convs and is_conv_with_downsampling(node):
        return PruningAnalysisDecision(False, PruningAnalysisReason.DOWNSAMPLE_CONV)

    return PruningAnalysisDecision(True)
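# A minimal, self-contained sketch of the same gating logic using plain booleans
# instead of NNCF types. `Conv` and `is_prunable` here are hypothetical stand-ins,
# not NNCF classes; the point is only to show the order in which the checks above
# reject a node: ignored scope first, then first-conv, grouped-conv, and
# downsampling checks.
from dataclasses import dataclass

@dataclass
class Conv:  # hypothetical stand-in for an NNCFNode with conv attributes
    name: str
    is_first: bool = False
    groups: int = 1
    in_channels: int = 1
    stride: int = 1

def is_prunable(conv: Conv, ignored: set, prune_first: bool, prune_downsample: bool) -> bool:
    if conv.name in ignored:
        return False
    if not prune_first and conv.is_first:
        return False
    # A grouped conv that is not depthwise (groups == in_channels) is not prunable
    if conv.groups > 1 and conv.groups != conv.in_channels:
        return False
    if not prune_downsample and conv.stride > 1:
        return False
    return True

assert not is_prunable(Conv('stem', is_first=True), set(), prune_first=False, prune_downsample=True)
assert is_prunable(Conv('block1'), set(), prune_first=False, prune_downsample=True)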
@classmethod
def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    input_mask = node.data['input_masks'][0]
    if input_mask is None:
        return
    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    new_num_channels = int(torch.sum(input_mask))

    is_depthwise = is_prunable_depthwise_conv(node)
    node_module = model.get_containing_module(node.node_name)
    old_num_channels = int(node_module.weight.size(1))

    if is_depthwise:
        # In the depthwise case the output channels are pruned by the input mask,
        # so here we only fix up the module attributes for the new number of input channels
        node_module.groups = new_num_channels
        node_module.in_channels = new_num_channels
        old_num_channels = int(node_module.weight.size(0))
    else:
        out_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.in_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(node_module.weight[broadcasted_mask].view(new_weight_shape))

    nncf_logger.info('Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
                     ' {}.'.format(node.data['key'], old_num_channels, new_num_channels))
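# Illustration only: how a boolean input mask shrinks the in_channels dimension
# of a convolution weight, mirroring the non-depthwise branch above. This is
# plain PyTorch with no NNCF types involved; the mask values are made up.
import torch

conv = torch.nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3)
bool_mask = torch.tensor([True, False, True, True])  # keep 3 of 4 input channels
new_in = int(bool_mask.sum())

out_channels = conv.weight.size(0)
broadcasted = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
new_shape = list(conv.weight.shape)
new_shape[1] = new_in

conv.in_channels = new_in
conv.weight = torch.nn.Parameter(conv.weight[broadcasted].view(new_shape))
assert conv.weight.shape == (2, 3, 3, 3)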
@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph,
                     tensor_processor: Type[NNCFPruningBaseTensorProcessor]) -> None:
    input_masks = get_input_masks(node, graph)
    output_mask = node.data.get('output_mask', None)

    # A grouped convolution stops mask propagation unless it is depthwise,
    # in which case the input mask passes through unchanged
    if is_grouped_conv(node):
        output_mask = None
        if is_prunable_depthwise_conv(node):
            output_mask = input_masks[0]

    node.data['output_mask'] = output_mask
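# A minimal sketch of the rule above with plain values standing in for
# NNCFNode.data (hypothetical helper, not an NNCF API): a depthwise conv
# (groups == in_channels) forwards its input mask, any other grouped conv
# blocks propagation, and an ordinary conv keeps its own output mask.
def propagate_conv_mask(groups: int, in_channels: int, input_mask, output_mask):
    if groups > 1:
        if groups == in_channels:   # depthwise: reuse the input mask
            return input_mask
        return None                 # generic grouped conv: stop propagation
    return output_mask

assert propagate_conv_mask(4, 4, [1, 0, 1, 1], None) == [1, 0, 1, 1]   # depthwise
assert propagate_conv_mask(2, 4, [1, 0, 1, 1], [1, 1]) is None         # grouped, not depthwise
assert propagate_conv_mask(1, 4, [1, 0, 1, 1], [1, 1]) == [1, 1]       # ordinary conv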
def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
    """
    Computes necessary model transformations (pruning mask insertions) to enable pruning.

    :param model: The original uncompressed model.
    :return: The instance of the `TransformationLayout` class containing
        a list of pruning mask insertions.
    """
    converter = TFModelConverterFactory.create(model)
    self._graph = converter.convert()
    groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

    transformations = TFTransformationLayout()
    shared_layers = set()
    self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)
    for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
        group_minfos = []
        for node in group.elements:
            layer_name = get_layer_identifier(node)
            layer = model.get_layer(layer_name)
            group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                is_prunable_depthwise_conv(node)))

            # Add output_mask to elements to run mask_propagation
            # and detect spec_nodes that will be pruned.
            # It should be done for all elements of a shared layer.
            node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
            if layer_name in shared_layers:
                continue
            if node.is_shared():
                shared_layers.add(layer_name)
            # Check that we need to prune weights in this op
            assert self._is_pruned_layer(layer)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(node.node_name)
            for weight_def in node.metatype.weight_definitions:
                transformations.register(
                    self._get_insertion_command_binary_mask(layer_info.layer_name,
                                                            weight_def.weight_attr_name)
                )
            if node.metatype.bias_attr_name is not None and \
                    getattr(layer, node.metatype.bias_attr_name) is not None:
                transformations.register(
                    self._get_insertion_command_binary_mask(layer_info.layer_name,
                                                            node.metatype.bias_attr_name)
                )

        cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
        self._pruned_layer_groups_info.add_cluster(cluster)

    # Propagate masks across the graph to detect spec_nodes that will be pruned
    mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                               TFNNCFPruningTensorProcessor)
    mask_propagator.mask_propagation()

    # Add masks for all spec modules, because prunable batch norm layers
    # can only be determined at mask-propagation time
    types_spec_layers = [TFBatchNormalizationLayerMetatype] if self._prune_batch_norms else []

    spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
    for spec_node in spec_nodes:
        layer_name = get_layer_identifier(spec_node)
        layer = model.get_layer(layer_name)
        if spec_node.data['output_mask'] is None:
            # Skip elements that will not be pruned
            continue
        if layer_name in shared_layers:
            continue
        if spec_node.is_shared():
            shared_layers.add(layer_name)
        nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

        _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
        for weight_def in spec_node.metatype.weight_definitions:
            if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                    and not layer.scale and weight_def.weight_attr_name == 'gamma':
                nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                  'Do not add mask to it.')
                continue

            transformations.register(
                self._get_insertion_command_binary_mask(layer_info.layer_name,
                                                        weight_def.weight_attr_name)
            )
        transformations.register(
            self._get_insertion_command_binary_mask(layer_info.layer_name,
                                                    spec_node.metatype.bias_attr_name)
        )
    return transformations
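# Illustration only: what a binary-mask insertion amounts to at apply time.
# A minimal sketch in plain TensorFlow, not NNCF's actual operation class:
# a per-filter 0/1 mask is broadcast over the kernel and multiplied in
# elementwise, zeroing out pruned filters. All values here are made up.
import tensorflow as tf

kernel = tf.random.normal((3, 3, 4, 8))              # HWIO conv kernel, 8 output filters
filter_mask = tf.constant([1., 1., 0., 1., 0., 1., 1., 1.])
binary_mask = tf.reshape(filter_mask, (1, 1, 1, 8))  # broadcast over H, W, in_channels
masked_kernel = kernel * binary_mask                 # pruned filters become all-zero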
def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
    """
    Groups ALL nodes with pruning types into groups that should be pruned together:
        1. Create clusters for special ops (eltwises) that should be pruned together
        2. Create groups of nodes that should be pruned together (taking into account
           clusters of special ops)
        3. Add remaining single nodes
        4. Unite clusters for Conv + Depthwise conv (should be pruned together too)
        5. Unite clusters for modules that are forwarded several times
        6. Check groups (all nodes in a group can be pruned, or the whole group can't be pruned)
    Returns groups of modules that should be pruned together.

    :param graph: Graph to work with.
    :return: Clusterization of pruned nodes.
    """
    # pylint:disable=too-many-branches
    all_nodes_to_prune = graph.get_nodes_by_types(self._prune_operations_types)  # NNCFNodes here

    # 1. Clusters for special ops
    identity_like_types = self._identity_mask_propagation_op_metatype.get_all_op_aliases()
    special_ops_clusterization = cluster_special_ops(graph, self._grouping_operations_types,
                                                     identity_like_types)

    pruned_nodes_clusterization = Clusterization[NNCFNode](lambda x: x.node_id)

    # 2. Clusters for nodes that should be pruned together
    # (taking into account clusters for special ops)
    for i, cluster in enumerate(special_ops_clusterization.get_all_clusters()):
        all_pruned_inputs = {}
        clusters_to_merge = []

        for node in cluster.elements:
            sources = get_sources_of_node(node, graph, self._prune_operations_types)
            for source_node in sources:
                if pruned_nodes_clusterization.is_node_in_clusterization(source_node.node_id):
                    # Merge clusters if some node was already added to another cluster
                    cluster = pruned_nodes_clusterization.get_cluster_containing_element(source_node.node_id)
                    clusters_to_merge.append(cluster.id)
                elif source_node.node_id not in all_pruned_inputs:
                    all_pruned_inputs[source_node.node_id] = source_node

        if all_pruned_inputs:
            cluster = Cluster[NNCFNode](i, all_pruned_inputs.values(), all_pruned_inputs.keys())
            clusters_to_merge.append(cluster.id)
            pruned_nodes_clusterization.add_cluster(cluster)

        # Merge clusters if one source node belongs to several special-ops clusters
        pruned_nodes_clusterization.merge_list_of_clusters(clusters_to_merge)

    last_cluster_idx = len(special_ops_clusterization.get_all_clusters())

    # 3. Add remaining single nodes as separate clusters
    for node in all_nodes_to_prune:
        if not pruned_nodes_clusterization.is_node_in_clusterization(node.node_id):
            cluster = Cluster[NNCFNode](last_cluster_idx, [node], [node.node_id])
            pruned_nodes_clusterization.add_cluster(cluster)
            last_cluster_idx += 1

    stop_propagation_ops = self._stop_propagation_op_metatype.get_all_op_aliases()
    # 4. Merge clusters for Conv + Depthwise conv (should be pruned together too)
    for node in all_nodes_to_prune:
        cluster_id = pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id

        if is_prunable_depthwise_conv(node):
            previous_convs = get_previous_convs(graph, node, self._prune_operations_types,
                                                stop_propagation_ops)
            previous_clusters = [
                pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id
                for node in previous_convs
            ]
            pruned_nodes_clusterization.merge_list_of_clusters([cluster_id] + previous_clusters)

    # 5. Merge nodes into one cluster if the same module is forwarded several times
    multiforward_nodes = self._get_multiforward_nodes(graph)
    for list_of_nodes in multiforward_nodes:
        clusters_to_merge = [
            pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id
            for node in list_of_nodes
        ]
        pruned_nodes_clusterization.merge_list_of_clusters(clusters_to_merge)

        # Merge previous convolutions into one cluster
        all_previous_convs = []
        for node in list_of_nodes:
            nncf_node = graph.get_node_by_id(node.node_id)
            previous_convs = get_previous_convs(graph, nncf_node, self._prune_operations_types,
                                                stop_propagation_ops)
            # Skip previous nodes that are multiforward themselves,
            # to avoid cycles between multiforward nodes
            for previous_conv in previous_convs:
                if previous_conv not in list_of_nodes:
                    all_previous_convs.append(previous_conv)

        previous_clusters = [
            pruned_nodes_clusterization.get_cluster_containing_element(node.node_id).id
            for node in all_previous_convs
        ]
        pruned_nodes_clusterization.merge_list_of_clusters(previous_clusters)

    # 6. Check groups (all nodes in a group can be pruned, or the whole group can't be pruned)
    model_analyser = ModelAnalyzer(graph, self._pruning_operator_metatypes, is_prunable_depthwise_conv)
    can_prune_analysis = model_analyser.analyse_model_before_pruning()
    can_prune_and_should_prune_analysis = self._should_prune_groups_analysis(
        graph, pruned_nodes_clusterization, can_prune_analysis)
    can_prune_final_analysis = self._pruning_dimensions_analysis(graph,
                                                                 can_prune_and_should_prune_analysis)
    self._filter_groups(pruned_nodes_clusterization, can_prune_final_analysis)
    return pruned_nodes_clusterization
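# A worked toy example of the grouping idea behind step 2, under simplifying
# assumptions and without NNCF types: an elementwise add requires its producer
# convs to share one pruning group, because their output channels must be
# pruned identically. A plain union-find stands in for Clusterization here
# (hypothetical helpers, not NNCF APIs).
def find(parent, x):
    while parent[x] != x:
        x = parent[x]
    return x

def union(parent, a, b):
    parent[find(parent, a)] = find(parent, b)

convs = ['conv1', 'conv2', 'conv3']
parent = {c: c for c in convs}
# residual block: conv1 and conv2 both feed the same `add` node
union(parent, 'conv1', 'conv2')

groups = {}
for c in convs:
    groups.setdefault(find(parent, c), []).append(c)
print(sorted(groups.values()))  # [['conv1', 'conv2'], ['conv3']]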
def _prune_weights(self, target_model: NNCFNetwork):
    target_model_graph = target_model.get_original_graph()
    groups_of_nodes_to_prune = self.pruning_node_selector.create_pruning_groups(target_model_graph)

    device = next(target_model.parameters()).device
    insertion_commands = []
    self.pruned_module_groups_info = Clusterization[PrunedModuleInfo](lambda x: x.node_name)

    for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
        group_minfos = []
        for node in group.elements:
            node_name = node.node_name
            module = target_model.get_containing_module(node_name)
            module_scope = target_model_graph.get_scope_by_node_name(node_name)

            # Check that we need to prune weights in this op
            assert self._is_pruned_module(module)

            nncf_logger.info("Adding Weight Pruner in scope: {}".format(node_name))
            pruning_block = self.create_weight_pruning_operation(module, node_name)
            # Hook for weights and bias
            hook = UpdateWeightAndBias(pruning_block).to(device)
            insertion_commands.append(
                PTInsertionCommand(
                    PTTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name=node_name),
                    hook,
                    TransformationPriority.PRUNING_PRIORITY
                )
            )
            group_minfos.append(
                PrunedModuleInfo(
                    node_name=node_name,
                    module_scope=module_scope,
                    module=module,
                    operand=pruning_block,
                    node_id=node.node_id,
                    is_depthwise=is_prunable_depthwise_conv(node)
                )
            )

        cluster = Cluster[PrunedModuleInfo](i, group_minfos, [n.node_id for n in group.elements])
        self.pruned_module_groups_info.add_cluster(cluster)

    # Propagate masks to find norm layers to prune
    init_output_masks_in_graph(target_model_graph, self.pruned_module_groups_info.get_all_nodes())
    MaskPropagationAlgorithm(target_model_graph, PT_PRUNING_OPERATOR_METATYPES,
                             PTNNCFPruningTensorProcessor).mask_propagation()

    # Adding binary masks also for Batch/Group Norms to allow applying masks after propagation
    types_to_apply_mask = ['group_norm']
    if self.prune_batch_norms:
        types_to_apply_mask.append('batch_norm')

    all_norm_layers = target_model_graph.get_nodes_by_types(types_to_apply_mask)
    for node in all_norm_layers:
        if node.data['output_mask'] is None:
            # Skip elements that will not be pruned
            continue

        node_name = node.node_name
        module = target_model.get_containing_module(node_name)

        pruning_block = self.create_weight_pruning_operation(module, node_name)
        # Hook for weights and bias
        hook = UpdateWeightAndBias(pruning_block).to(device)
        insertion_commands.append(
            PTInsertionCommand(
                PTTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name=node_name),
                hook,
                TransformationPriority.PRUNING_PRIORITY
            )
        )
        self._pruned_norms_operators.append((node, pruning_block, module))
    return insertion_commands
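# Illustration only: the effect of a pre-layer weight-masking hook, using a
# plain PyTorch forward_pre_hook instead of NNCF's UpdateWeightAndBias and
# insertion commands. The mask and layer here are made-up stand-ins.
import torch

conv = torch.nn.Conv2d(3, 4, kernel_size=3, bias=False)
filter_mask = torch.tensor([1., 0., 1., 1.]).view(-1, 1, 1, 1)

def apply_filter_mask(module, inputs):
    # Zero out pruned output filters right before the forward pass
    module.weight.data.mul_(filter_mask)

conv.register_forward_pre_hook(apply_filter_mask)
out = conv(torch.randn(1, 3, 8, 8))
assert torch.all(out[:, 1] == 0)  # the masked filter produces zeros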
@classmethod
def accept_pruned_input(cls, node: NNCFNode) -> bool:
    if is_grouped_conv(node) and not is_prunable_depthwise_conv(node):
        return False
    return True