Example #1
0
    def _prune(self, graph, pruning_spec):
        """Prune `graph` according to `pruning_spec` and rebuild a module.

        Uses `ChannelPruner` to perform channel pruning on the given graph,
        then rebuilds a Python module from the pruned graph.

        Arguments:
          graph: A `NndctGraph` to be pruned.
          pruning_spec: A `PruningSpec` object that indicates how to prune
            the model.

        Returns:
          A `torch.nn.Module` object rebuilt from the pruned `NndctGraph` model.
          A dict of `NodePruningInfo` in which the key is the name of the node.
        """

        pruner = pruner_lib.ChannelPruner(graph)
        pruned_graph, pruning_info = pruner.prune(pruning_spec)

        # TODO(yuwang): Support user-provided path.
        rebuilt_module = utils.rebuild_module(pruned_graph, './graph.py')

        # NodePruningInfo.state_dict_keys is only used to validate pruning result,
        # it has no effect on pruned graph.
        # NOTE: The current approach relies on TorchParser's implementation.
        # If the tensor's name no longer comes from the original module's
        # state dict key, then the following code will not work.
        # tensor.name -> state_dict_key, e.g. ResNet::conv1.weight -> conv1.weight
        prefix = pruned_graph.name + '::'
        for node in pruned_graph.nodes:
            node_pruning = pruning_info[node.name]
            # BUG FIX: str.lstrip(prefix) strips a *character set*, not the
            # literal prefix, so it could also eat leading characters of the
            # actual key (e.g. a key starting with 'e' after graph name 'Net').
            # Remove the exact prefix instead.
            node_pruning.state_dict_keys = [
                tensor.name[len(prefix):] if tensor.name.startswith(prefix)
                else tensor.name
                for tensor in node.op.params.values()
            ]

        return rebuilt_module, pruning_info
Example #2
0
    def __call__(self, input_queue, output_queue):
        """Worker loop: score one pruned model per queued analysis step.

        Drains step indices from `input_queue`, prunes a freshly parsed graph
        with each step's spec, evaluates the rebuilt model on GPU via
        `self.eval_fn`, and pushes `(step, score)` pairs onto `output_queue`.
        """
        # We have to reparse graph here to recreate the global variable
        # _NNDCT_OP_2_TORCH_OP. (obviously not a good idea)
        # TorchScriptWriter use this global map to generate the script.
        parsed_graph = parse_to_graph(self.module, self.input_specs)
        channel_pruner = pruner_lib.ChannelPruner(parsed_graph)
        model_analyser = ana_lib.ModelAnalyser(parsed_graph)
        total_steps = model_analyser.steps()

        while not input_queue.empty():
            step = input_queue.get()
            step_spec = model_analyser.spec(step)
            pruned_graph, _ = channel_pruner.prune(step_spec)

            module, _ = utils.rebuild_module(pruned_graph)
            gpu_module = module.cuda()
            gpu_module.eval()
            step_score = self.eval_fn(gpu_module, *self.args).item()
            output_queue.put((step, step_score))
            logging.info('Analysis complete %d/%d' % (step + 1, total_steps))
Example #3
0
    def _prune(self, graph, pruning_spec, output_script=None):
        """Use `ChannelPruner` to perform pruning on the given graph by given spec.

        Arguments:
          graph: A `NndctGraph` to be pruned.
          pruning_spec: A `PruningSpec` object that indicates how to prune
            the model.
          output_script: Filepath that saves the generated script used for
            rebuilding model. If None, then the generated script will be written
            to a tempfile.

        Returns:
          A `torch.nn.Module` object rebuilt from the pruned `NndctGraph` model.
          A pruned nndct graph.
          A dict of `NodePruningResult` that indicates how each node is pruned.
        """

        pruner = pruner_lib.ChannelPruner(graph)
        pruned_graph, pruning_info = pruner.prune(pruning_spec)

        rebuilt_module, filename = utils.rebuild_module(pruned_graph)
        if output_script:
            shutil.move(filename, output_script)

        # NOTE: The current approach relies on TorchParser's implementation.
        # If the tensor's name no longer comes from the original module's
        # state dict key, then the following code will not work.
        # tensor.name -> state_dict_key, e.g. ResNet::conv1.weight -> conv1.weight
        prefix = pruned_graph.name + '::'
        for node in pruned_graph.nodes:
            node_pruning = pruning_info[node.name]
            # BUG FIX: str.lstrip(prefix) strips a *character set*, not the
            # literal prefix, so it could also eat leading characters of the
            # actual key (e.g. a key starting with 'e' after graph name 'Net').
            # Remove the exact prefix instead.
            node_pruning.state_dict_keys = [
                tensor.name[len(prefix):] if tensor.name.startswith(prefix)
                else tensor.name
                for tensor in node.op.params.values()
            ]

        return rebuilt_module, pruned_graph, pruning_info