def prune(self, threshold=None, sparsity=None):
  """Prunes the model, either by a sensitivity threshold or a uniform sparsity."""
  spec = pruning_lib.PruningSpec()
  if threshold:
    sens_path = pruning_lib.sens_path(self._graph)
    if not os.path.exists(sens_path):
      raise RuntimeError("Must call ana() before running prune().")
    net_sens = pruning_lib.read_sens(sens_path)
    # TODO(yuwang): Support excludes: important to detection net.
    net_sparsity = pruning_lib.get_sparsity_by_threshold(net_sens, threshold)
    logging.vlog(
        1, 'NetSparsity: \n{}'.format('\n'.join(
            [str(group) for group in net_sparsity])))
    for group_sparsity in net_sparsity:
      spec.add_group(group_sparsity)
  elif sparsity:
    groups = pruning_lib.group_nodes(self._graph)
    for group in groups:
      spec.add_group(pruning_lib.GroupSparsity(group, sparsity))
  else:
    raise ValueError(
        "At least one of 'sparsity' or 'threshold' must be set.")

  pruned_model, pruning_info = self._prune(self._graph, spec)
  return PruningModule(pruned_model, pruning_info)

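# The threshold branch above leans on pruning_lib.get_sparsity_by_threshold()
# to turn recorded sensitivity data into per-group sparsities. The sketch
# below only illustrates that idea and is not the library implementation: the
# dict-of-dicts sensitivity layout, the function name, and the rule "pick the
# largest sparsity whose accuracy drop stays within the threshold" are all
# assumptions made for this example.


def sparsity_by_threshold_sketch(net_sens, threshold):
  """Returns {group: sparsity}, choosing the largest tolerable sparsity."""
  chosen = {}
  for group, drops in net_sens.items():
    best = 0.0
    for sparsity, acc_drop in sorted(drops.items()):
      if acc_drop <= threshold:
        best = max(best, sparsity)
    chosen[group] = best
  return chosen


# With a 0.05 accuracy-drop budget, conv1 tolerates 0.5 sparsity while conv2
# only tolerates 0.2.
toy_sens = {
    'conv1': {0.2: 0.01, 0.5: 0.04, 0.7: 0.09},
    'conv2': {0.2: 0.03, 0.5: 0.08, 0.7: 0.15},
}
assert sparsity_by_threshold_sketch(toy_sens, 0.05) == {'conv1': 0.5,
                                                        'conv2': 0.2}
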
def record(self, step, result):
  """Records the evaluation result of the given analysis step."""
  if step >= self._total_steps:
    raise IndexError(
        "step {} out of range [0, {})".format(step, self._total_steps))
  logging.vlog(3, "step {} recorded as {}".format(step, result))
  _, sparsity = self._eval_plan(step)
  self._metrics[step] = AnaMetric(sparsity, result)

def group_nodes(graph, nodes_to_exclude=[]):
  """Divide conv2d nodes into different groups.

  Nodes that are connected to each other by elementwise operations
  are placed in the same group.
  """
  node_group = node_group_lib.NodeGroup()
  for node in graph.nodes:
    if node.op.type == NNDCT_OP.CONV2D:
      node_group.add_node(node.name)

  for node in graph.nodes:
    if node.op.type != NNDCT_OP.ADD:
      continue
    eltwise_inputs = []
    for name in node.in_nodes:
      input_node = graph.node(name)
      # Depthwise conv must be treated as a slave node.
      if input_node.op.type == NNDCT_OP.CONV2D and not is_depthwise(
          input_node.op):
        eltwise_inputs.append(name)
      else:
        ancestor = find_node_ancestor(graph, input_node, [NNDCT_OP.CONV2D],
                                      [NNDCT_OP.CONCAT])
        if ancestor and not is_depthwise(input_node.op):
          eltwise_inputs.append(ancestor.name)
    if len(eltwise_inputs) < 2:
      continue
    logging.vlog(
        2, "Union ({}, {})".format(eltwise_inputs[0], eltwise_inputs[1]))
    node_group.union(eltwise_inputs[0], eltwise_inputs[1])

  # TODO(yuwang): Exclude group convolution nodes.
  # i.e. groups > 1 and groups != in_channels.
  all_groups = node_group.groups()
  groups = []
  for group in all_groups:
    skip = False
    for node in nodes_to_exclude:
      if node in group:
        skip = True
        break
    if not skip:
      groups.append(group)
  return groups

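# group_nodes() above relies on node_group_lib.NodeGroup exposing add_node(),
# union(), and groups(). The minimal union-find sketch below shows one way
# such a helper could behave; it is an assumption about the interface, not the
# actual node_group_lib implementation.


class UnionFindNodeGroup(object):
  """Disjoint set of node names; union() merges the sets of two nodes."""

  def __init__(self):
    self._parent = {}

  def add_node(self, name):
    self._parent.setdefault(name, name)

  def _find(self, name):
    # Find the root, then compress the path so later lookups are O(1)-ish.
    root = name
    while self._parent[root] != root:
      root = self._parent[root]
    while self._parent[name] != root:
      self._parent[name], name = root, self._parent[name]
    return root

  def union(self, a, b):
    self.add_node(a)
    self.add_node(b)
    self._parent[self._find(a)] = self._find(b)

  def groups(self):
    by_root = {}
    for name in self._parent:
      by_root.setdefault(self._find(name), []).append(name)
    return list(by_root.values())


# Two convs feeding the same elementwise add end up in one group, so they
# must be pruned with identical output channels.
ng = UnionFindNodeGroup()
for conv in ('conv1', 'conv2', 'conv3'):
  ng.add_node(conv)
ng.union('conv1', 'conv2')
assert sorted(sorted(g) for g in ng.groups()) == [['conv1', 'conv2'], ['conv3']]
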
def prune(self, pruning_spec):
  """Prunes the graph according to the given pruning spec."""
  for index, group in enumerate(pruning_spec.groups):
    logging.vlog(1, "Group {}: {}".format(index, group))

  node_pruning_results = {}
  for node in self._graph.nodes:
    node_pruning_results[node.name] = pruning_lib.NodePruningResult(node.name)
    if node.op.type != NNDCT_OP.CONV2D:
      continue
    group = pruning_spec.group(node.name)
    if not group:
      # Set out_dim even though the node is not going to be pruned.
      node_pruning_results[node.name].out_dim = node.op.attr['out_dim']
      continue
    removed_outputs, out_dim = self._get_pruned_filters(
        node.op.param['weights'], pruning_spec.channel_batch, group.sparsity)
    logging.vlog(3, 'node: {}, removed outputs: {}, out_dim: {}'.format(
        node.name, removed_outputs, out_dim))
    for name in group.nodes:
      node_pruning_results[name] = pruning_lib.NodePruningResult(
          name, group.sparsity, removed_outputs, out_dim)
    node_pruning_results[node.name].master = True

  pruned_graph = self._generate_pruned_graph(node_pruning_results)
  for node in pruned_graph.nodes:
    node_pruning = node_pruning_results[node.name]
    for param, tensor in node.op.params.items():
      org_tensor_shape = tensor.shape
      dim_size = len(tensor.shape)
      if dim_size == 1:
        out_axis, in_axis = 0, dim_size
      else:
        out_axis, in_axis = utils.tensor_out_in_axis(tensor)
      # The meaning of OI is not the same in nndct and pytorch.
      # In nndct, 'O' means channel multiplier (out_channels // in_channels)
      # and 'I' means in_channels. However, in pytorch, 'O' is out_channels
      # and 'I' is channel multiplier.
      # For example, the weight shape of depthwise conv in pytorch is
      # (32, 1, 3, 3) while in nndct the shape is (1, 3, 3, 32).
      # The weight data format in nndct is OHWI, that is,
      # (channel_multiplier, height, width, in_channels). We have to exchange
      # out_axis and in_axis so that the correct dimension can be removed.
      if pruning_lib.is_depthwise(node.op):
        out_axis, in_axis = in_axis, out_axis
      ndarray = tensor.data
      if node_pruning.removed_outputs:
        ndarray = np.delete(ndarray, node_pruning.removed_outputs,
                            axis=out_axis)
      if node_pruning.removed_inputs and dim_size > in_axis:
        ndarray = np.delete(ndarray, node_pruning.removed_inputs,
                            axis=in_axis)
      tensor.from_ndarray(ndarray)
      if org_tensor_shape != tensor.shape:
        logging.vlog(4, "Reset param of {}({}) {}: {} -> {}".format(
            node.name, node.op.type, param.name, org_tensor_shape,
            tensor.shape))
  return pruned_graph, node_pruning_results

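# A self-contained numpy illustration of the channel removal performed above:
# np.delete() drops whole slices along the output axis of one conv's weights
# and along the matching input axis of the conv that consumes them. The shapes
# (OHWI layout, as described in the comment above) and the removed indices are
# made up for this example.
import numpy as np

# conv1: 8 output channels, 3x3 kernel, 4 input channels -> OHWI = (8, 3, 3, 4)
conv1_w = np.random.rand(8, 3, 3, 4)
# conv2 consumes conv1's 8 channels: OHWI = (16, 3, 3, 8)
conv2_w = np.random.rand(16, 3, 3, 8)

removed_outputs = [1, 5, 6]  # channels pruned from conv1

# Remove conv1's pruned channels along its 'O' axis...
conv1_w = np.delete(conv1_w, removed_outputs, axis=0)
# ...and the corresponding input channels of conv2 along its 'I' axis.
conv2_w = np.delete(conv2_w, removed_outputs, axis=3)

assert conv1_w.shape == (5, 3, 3, 4)
assert conv2_w.shape == (16, 3, 3, 5)
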
def prune(self, pruning_spec):
  """Prunes the graph according to the given pruning spec."""
  for index, group in enumerate(pruning_spec.groups):
    logging.vlog(1, "Group {}: {}".format(index, group))

  pruning_info = {}
  for node in self._graph.nodes:
    pruning_info[node.name] = pruning_lib.NodePruningInfo(node.name)
    if node.op.type != NNDCT_OP.CONV2D:
      continue
    group = pruning_spec.group(node.name)
    if not group:
      continue
    removed_outputs, out_dim = self._get_pruned_filters(
        node.op.param['weight'], pruning_spec.channel_batch, group.sparsity)
    logging.vlog(3, 'Removed output channels of {}: {}'.format(
        node.name, removed_outputs))
    for name in group.nodes:
      pruning_info[name] = pruning_lib.NodePruningInfo(
          name, removed_outputs, out_dim)
    # Mark the master node after the group loop so the flag is not
    # overwritten by the NodePruningInfo objects created above.
    pruning_info[node.name].master = True

  pruned_graph = self._generate_pruned_graph(pruning_info)
  for node in pruned_graph.nodes:
    node_pruning = pruning_info[node.name]
    for param, tensor in node.op.params.items():
      org_tensor_shape = tensor.shape
      dim_size = len(tensor.shape)
      if dim_size == 1:
        out_axis, in_axis = 0, dim_size
      else:
        out_axis, in_axis = utils.tensor_out_in_axis(tensor)
      ndarray = tensor.data
      if node_pruning.removed_outputs:
        ndarray = np.delete(ndarray, node_pruning.removed_outputs,
                            axis=out_axis)
      if node_pruning.removed_inputs and dim_size > in_axis:
        ndarray = np.delete(ndarray, node_pruning.removed_inputs,
                            axis=in_axis)
      tensor.from_ndarray(ndarray)
      if org_tensor_shape != tensor.shape:
        logging.vlog(1, "Reset param of {}({}) {}: {} -> {}".format(
            node.name, node.op.type, param.name, org_tensor_shape,
            tensor.shape))
  return pruned_graph, pruning_info

def record(self, step, result):
  """Records the evaluation result of the given analysis step."""
  if step >= self._total_steps:
    raise IndexError(
        "step {} out of range [0, {})".format(step, self._total_steps))
  logging.vlog(3, "step {} recorded as {}".format(step, result))
  self._results[step] = result