Exemplo n.º 1
0
 def __init__(self,
              model,
              config_list,
              pruning_algorithm,
              optimizer=None,
              dependency_aware=False,
              dummy_input=None,
              **algo_kwargs):
     """
     Initialize the pruner and, in dependency-aware mode, collect the
     channel/group dependency information of the model.

     Parameters
     ----------
     model
         The model to prune; forwarded to the parent constructor and,
         in dependency-aware mode, traced with ``dummy_input``.
     config_list
         Forwarded unchanged to the parent constructor.
     pruning_algorithm
         Forwarded to the parent constructor as a keyword argument.
     optimizer
         Forwarded to the parent constructor as a keyword argument.
     dependency_aware : bool
         Stored on the pruner and on its masker. When true, the
         dependency sets are computed here.
     dummy_input
         Example input used to trace the model; must not be None when
         ``dependency_aware`` is set.
     **algo_kwargs
         Extra keyword arguments forwarded to the parent constructor.
     """
     super().__init__(
         model,
         config_list,
         pruning_algorithm=pruning_algorithm,
         optimizer=optimizer,
         **algo_kwargs)
     self.dependency_aware = dependency_aware
     # propagate the dependency-aware switch to the masker
     self.masker.dependency_aware = dependency_aware
     self.dummy_input = dummy_input
     if self.dependency_aware:
         errmsg = "When dependency_aware is set, the dummy_input should not be None"
         assert self.dummy_input is not None, errmsg
         # The wrappers have to be removed before the model can be traced,
         # and are put back right after the graph has been built.
         self._unwrap_model()
         self.graph = TorchModuleGraph(model, dummy_input)
         self._wrap_model()
         channel_analysis = ChannelDependency(traced_model=self.graph.trace)
         group_analysis = GroupDependency(traced_model=self.graph.trace)
         # map every layer name to the dependency set that contains it
         layer_to_set = {}
         for dep_set in channel_analysis.dependency_sets:
             for layer_name in dep_set:
                 layer_to_set[layer_name] = dep_set
         self.channel_depen = layer_to_set
         self.group_depen = group_analysis.dependency_sets
Exemplo n.º 2
0
 def __init__(self, model=None, dummy_input=None, traced_model=None):
     """
     Build the graph for the model and then compute its dependencies.

     Either ``traced_model`` or both ``model`` and ``dummy_input``
     have to be supplied.
     """
     must_trace = traced_model is None
     if must_trace:
         # without an already traced model, both the model and an
         # example input are needed to trace it
         assert not (model is None or dummy_input is None)
     self.graph = TorchModuleGraph(model, dummy_input, traced_model)
     self.dependency = {}
     self.build_dependency()
Exemplo n.º 3
0
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Analyze the channel dependencies between the conv layers
        of a model.

        Parameters
        ----------
        model : torch.nn.Module
            The model to be analyzed.
        dummy_input : torch.Tensor
            The example input data used to trace the network architecture.
        traced_model : torch._C.Graph
            An already traced graph of the target model; when it is given
            the model does not need to be traced again.
        """
        must_trace = traced_model is None
        if must_trace:
            # without an already traced model, both the model and an
            # example input are needed to trace it
            assert not (model is None or dummy_input is None)
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        self.dependency = {}
        self.build_channel_dependency()
Exemplo n.º 4
0
    def test_module_reuse(self):
        """
        Check that a module reused several times in ``forward`` still
        yields a single top-down path in the TorchModuleGraph.
        """
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.liner1 = nn.Linear(10, 10)
                self.relu = nn.ReLU(inplace=True)
                self.liner2 = nn.Linear(10, 20)
                self.liner3 = nn.Linear(20, 10)

            def forward(self, x):
                # the very same relu module is applied after each linear layer
                x = self.liner1(x)
                x = self.relu(x)
                x = self.liner2(x)
                x = self.relu(x)
                x = self.liner3(x)
                x = self.relu(x)
                return x

        data = torch.rand(10, 10)
        net = MyModule()
        traced = torch.jit.trace(net, data)
        graph = TorchModuleGraph(traced_model=traced)
        # Because the relu module is reused, three cpp nodes correspond to
        # the same module. While walking the graph, every node (including
        # those three) must still have exactly one successor.
        for io_name, io_node in graph.nodes_py.nodes_io.items():
            if io_node.input_or_output != 'input':
                continue
            # the nodes consuming the graph input are where the walk starts
            entry_nodes = graph.input_to_node[io_name]
            # there is only one single path from top to bottom
            assert len(entry_nodes) == 1
            current = entry_nodes[0].unique_name
            while graph.find_successors(current):
                following = graph.find_successors(current)
                assert len(following) == 1
                current = following[0]
Exemplo n.º 5
0
class ChannelDependency:
    """
    Analyze the channel dependencies between the conv layers of a model.

    The result is kept in ``self.dependency``, a dict that maps a layer
    name to the set of layer names it shares a channel dependency with.
    """

    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Build the module graph and compute the channel dependencies.

        Parameters
        ----------
        model : torch.nn.Module
            The model to be analyzed.
        dummy_input : torch.Tensor
            The example input data used to trace the network architecture.
        traced_model : torch._C.Graph
            An already traced graph of the target model; when it is given
            the model does not need to be traced again.

        Raises
        ------
        AssertionError
            If ``traced_model`` is None and ``model`` or ``dummy_input``
            is missing.
        """
        # check if the input is legal
        if traced_model is None:
            # user should provide model & dummy_input to trace the model,
            # or an already traced model
            assert model is not None and dummy_input is not None, \
                'model and dummy_input are required when traced_model is None'
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        self.dependency = dict()
        self.build_channel_dependency()

    def _get_parent_layers(self, node):
        """
        Find the nearest father conv/linear layers for the target node.

        The predecessors of ``node`` are walked breadth-first; the walk
        stops descending at the first Conv2d/Linear layer met on each path.

        Parameters
        ---------
        node : torch._C.Node
            target node.

        Returns
        -------
        parent_layers: list
            names of the nearest father conv/linear layers of the target
            node, each reached node being reported once.
        """
        parent_layers = []
        queue = [node]
        # unique names already queued: a node reachable through several
        # paths is expanded only once, which avoids duplicated results and
        # repeated traversal of shared ancestors
        seen = {node.unique_name}
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type in ('Conv2d', 'Linear'):
                # find the first met conv/linear, do not look further up
                parent_layers.append(curnode.name)
                continue
            for name in self.graph.find_predecessors(curnode.unique_name):
                if name in seen:
                    continue
                seen.add(name)
                queue.append(self.graph.name_to_node[name])
        return parent_layers

    def build_channel_dependency(self):
        """
        Build the channel dependency for the conv layers
        in the model and store it in ``self.dependency``.
        """
        for node in self.graph.nodes_py.nodes_op:
            parent_layers = []
            # find the node that contains aten::add
            # or aten::cat operations
            if node.op_type in ADD_TYPES:
                parent_layers = self._get_parent_layers(node)
            elif node.op_type == CAT_TYPE:
                # To determine if this cat operation will introduce channel
                # dependency, we need the specific input parameters of the cat
                # operation. To get the input parameters of the cat operation,
                # we need to traverse all the cpp_nodes included by this
                # NodePyGroup, because TorchModuleGraph merges the important
                # nodes and the adjacent unimportant nodes (nodes started with
                # prim::attr, for example) into a NodePyGroup.
                cat_dim = None
                for cnode in node.node_cpps:
                    if cnode.kind() == CAT_TYPE:
                        cat_dim = list(cnode.inputs())[1].toIValue()
                        break
                # A cat along dimension 1 is skipped (presumably the channel
                # dimension - TODO confirm). Note that a cat whose dimension
                # could not be found (None) is treated as a dependency.
                if cat_dim != 1:
                    parent_layers = self._get_parent_layers(node)
            dependency_set = set(parent_layers)
            # merge the dependencies already recorded for these layers
            for parent in parent_layers:
                if parent in self.dependency:
                    dependency_set.update(self.dependency[parent])
            # save the dependencies: every member shares the same set object
            for _node in dependency_set:
                self.dependency[_node] = dependency_set

    def export(self, filepath):
        """
        export the channel dependencies as a csv file.
        The layers at the same line have output channel
        dependencies with each other. For example,
        layer1.1.conv2, conv1, and layer1.0.conv2 have
        output channel dependencies with each other, which
        means the output channel(filters) numbers of these
        three layers should be same with each other, otherwise
        the model may has shape conflict.

        Output example:
        Dependency Set,Convolutional Layers
        Set 1,layer1.1.conv2,layer1.0.conv2,conv1
        Set 2,layer1.0.conv1
        Set 3,layer1.1.conv1

        Parameters
        ----------
        filepath : str
            path of the csv file to write; an existing file is overwritten.
        """
        header = ['Dependency Set', 'Convolutional Layers']
        # reuse the grouping logic instead of duplicating it here
        groups = self.dependency_sets
        # newline='' lets the csv module control the line endings itself,
        # otherwise rows end with '\r\r\n' on platforms translating '\n'
        with open(filepath, 'w', newline='') as csvf:
            csv_w = csv.writer(csvf, delimiter=',')
            csv_w.writerow(header)
            for setid, layer_names in enumerate(groups, start=1):
                csv_w.writerow(['Set %d' % setid, *layer_names])

    @property
    def dependency_sets(self):
        """
        Get the list of the dependency set.

        Only the sets reachable from a Conv2d node are reported; a Conv2d
        layer without any recorded dependency forms a set on its own.

        Returns
        -------
        dependency_sets : list
            list of the dependency sets. For example,
            [set(['conv1', 'conv2']), set(['conv3', 'conv4'])]

        """
        d_sets = []
        visited = set()
        for node in self.graph.nodes_py.nodes_op:
            if node.op_type != 'Conv2d' or node in visited:
                continue
            tmp_set = set()
            if node.name not in self.dependency:
                visited.add(node)
                tmp_set.add(node.name)
            else:
                # mark every member as visited so the set is emitted once
                for other in self.dependency[node.name]:
                    visited.add(self.graph.name_to_node[other])
                    tmp_set.add(other)
            d_sets.append(tmp_set)
        return d_sets
Exemplo n.º 6
0
    def test_module_unpack(self):
        """
        Exercise the tuple/list unpacking of TorchModuleGraph.

        The models below come from issue 2756
        (https://github.com/microsoft/nni/issues/2756).
        MyModule produces two successive tuple unpack operations
        between the B and C sub-modules.
        """
        class CBR(nn.Module):
            def __init__(self, i, o):
                super().__init__()
                self.conv1 = nn.Conv2d(i, o, kernel_size=1)
                self.bn1 = nn.BatchNorm2d(o)
                self.act1 = nn.ReLU()

            def forward(self, x):
                return self.act1(self.bn1(self.conv1(x)))

        class A(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = CBR(3, 6)
                self.conv2 = CBR(6, 8)
                self.conv3 = CBR(6, 12)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                x3 = self.conv3(x1)
                return (x2, x3)

        class B1(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = CBR(12, 32)
                self.conv2 = CBR(32, 32)
                self.conv3 = CBR(32, 32)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                x3 = self.conv3(x2)
                return (x1, x2, x3)

        class B(nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B1()

            def forward(self, x):
                # only the last element of the incoming tuple is consumed
                return self.b(x[-1])

        class C(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = CBR(8, 32)
                self.conv2 = CBR(12, 32)
                self.conv3 = CBR(32, 32)
                self.conv4 = CBR(32, 32)
                self.conv5 = CBR(32, 32)

            def forward(self, x):
                return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
                        self.conv4(x[3]), self.conv5(x[4]))

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()
                self.b = B()
                self.c = C()

            def forward(self, x):
                x_a = self.a(x)
                x_b = self.b(x_a)
                # tuple concatenation: two outputs of A plus three of B
                xc = self.c(x_a + x_b)
                return xc

        dummy_input = torch.rand(1, 3, 28, 28)
        model = MyModule()
        graph = TorchModuleGraph(model, dummy_input)
        graph.unpack_manually()
        # After the manual unpacking no TupleUnpack node is left, so no
        # node may take its input from such a node anymore.
        for op_node in graph.nodes_py.nodes_op:
            for in_name in op_node.inputs:
                if in_name not in graph.output_to_node:
                    continue
                producer = graph.output_to_node[in_name]
                assert producer.op_type != TUPLE_UNPACK_KIND