def _prune(self, graph, pruning_spec):
  """Prune `graph` according to `pruning_spec` using `ChannelPruner`.

  Arguments:
    graph: A `NndctGraph` to be pruned.
    pruning_spec: A `PruningSpec` object that indicates how to prune the
      model.

  Returns:
    A `torch.nn.Module` object rebuilt from the pruned `NndctGraph` model.
    A dict of `NodePruningInfo` keyed by node name.
  """
  pruner = pruner_lib.ChannelPruner(graph)
  pruned_graph, pruning_info = pruner.prune(pruning_spec)
  # TODO(yuwang): Support user-provided path.
  rebuilt_module = utils.rebuild_module(pruned_graph, './graph.py')

  # NodePruningInfo.state_dict_keys is only used to validate pruning result,
  # it has no effect on pruned graph.
  # NOTE: The current approach relies on TorchParser's implementation.
  # If the tensor's name no longer comes from the original module's
  # state dict key, then the following code will not work.
  prefix = pruned_graph.name + '::'
  for node in pruned_graph.nodes:
    node_pruning = pruning_info[node.name]
    node_pruning.state_dict_keys = []
    for tensor in node.op.params.values():
      # tensor.name -> state_dict_key
      # ResNet::conv1.weight -> conv1.weight
      # Remove the graph-name prefix exactly once. `str.lstrip` must not be
      # used here: it strips any leading characters contained in its
      # argument, e.g. 'ResNet::stem.weight'.lstrip('ResNet::') == 'm.weight'.
      key = tensor.name
      if key.startswith(prefix):
        key = key[len(prefix):]
      node_pruning.state_dict_keys.append(key)
  return rebuilt_module, pruning_info
def __call__(self, input_queue, output_queue):
  """Evaluate pruning steps taken from `input_queue` until it is drained.

  For every step index taken from `input_queue`, this builds the step's
  pruning spec with `ModelAnalyser.spec`, prunes the graph, and rebuilds a
  module from the pruned graph. It then moves the module to CUDA in eval mode
  and scores it with `self.eval_fn(module, *self.args).item()`. Each result
  is put on `output_queue` as a `(step, score)` tuple.

  Arguments:
    input_queue: Queue of step indices to evaluate. It may be shared with
      other workers (assumed from the queue-based design; confirm with the
      caller).
    output_queue: Queue that receives `(step, score)` tuples.
  """
  import queue  # Only used for the `queue.Empty` exception type.

  # We have to reparse graph here to recreate the global variable
  # _NNDCT_OP_2_TORCH_OP. (obviously not a good idea)
  # TorchScriptWriter use this global map to generate the script.
  graph = parse_to_graph(self.module, self.input_specs)
  pruner = pruner_lib.ChannelPruner(graph)
  analyser = ana_lib.ModelAnalyser(graph)
  steps = analyser.steps()
  while True:
    # Checking `empty()` and then calling a blocking `get()` is racy when
    # several workers consume the same queue. Another worker may take the
    # last item in between, and this worker would then block forever. A
    # non-blocking get that stops on `queue.Empty` cannot hang.
    try:
      cur_step = input_queue.get_nowait()
    except queue.Empty:
      break
    spec = analyser.spec(cur_step)
    pruned_graph, _ = pruner.prune(spec)
    rebuilt_module, _ = utils.rebuild_module(pruned_graph)
    module = rebuilt_module.cuda()
    module.eval()
    score = self.eval_fn(module, *self.args).item()
    output_queue.put((cur_step, score))
    logging.info('Analysis complete %d/%d', cur_step + 1, steps)
def _prune(self, graph, pruning_spec, output_script=None):
  """Use `ChannelPruner` to prune the given graph according to the given spec.

  Arguments:
    graph: A `NndctGraph` to be pruned.
    pruning_spec: A `PruningSpec` object that indicates how to prune the
      model.
    output_script: Filepath to save the generated script used to rebuild
      the model. If None, the generated script is left at the path returned
      by `utils.rebuild_module`.

  Returns:
    A `torch.nn.Module` object rebuilt from the pruned `NndctGraph` model.
    A pruned nndct graph.
    A dict of `NodePruningResult` that indicates how each node is pruned.
  """
  pruner = pruner_lib.ChannelPruner(graph)
  pruned_graph, pruning_info = pruner.prune(pruning_spec)
  rebuilt_module, filename = utils.rebuild_module(pruned_graph)
  if output_script:
    shutil.move(filename, output_script)

  # NOTE: The current approach relies on TorchParser's implementation.
  # If the tensor's name no longer comes from the original module's
  # state dict key, then the following code will not work.
  prefix = pruned_graph.name + '::'
  for node in pruned_graph.nodes:
    node_pruning = pruning_info[node.name]
    node_pruning.state_dict_keys = []
    for tensor in node.op.params.values():
      # tensor.name -> state_dict_key
      # ResNet::conv1.weight -> conv1.weight
      # Remove the graph-name prefix exactly once. `str.lstrip` must not be
      # used here: it strips any leading characters contained in its
      # argument, e.g. 'ResNet::stem.weight'.lstrip('ResNet::') == 'm.weight'.
      key = tensor.name
      if key.startswith(prefix):
        key = key[len(prefix):]
      node_pruning.state_dict_keys.append(key)
  return rebuilt_module, pruned_graph, pruning_info