コード例 #1
0
ファイル: one_shot.py プロジェクト: zhuqingling/nni
 def __init__(self,
              model,
              config_list,
              pruning_algorithm,
              optimizer=None,
              dependency_aware=False,
              dummy_input=None,
              **algo_kwargs):
     """
     Initialize the pruner and, when requested, its dependency analysis.

     Parameters
     ----------
     model
         The model; forwarded to the parent constructor and, when
         ``dependency_aware`` is set, traced into a ``TorchModuleGraph``.
     config_list
         Forwarded unchanged to the parent constructor.
     pruning_algorithm
         Forwarded unchanged to the parent constructor.
     optimizer
         Optional; forwarded unchanged to the parent constructor.
     dependency_aware : bool
         Stored on the pruner and copied onto ``self.masker``. When truthy,
         the channel and group dependencies of the model are computed.
     dummy_input
         Example input used to trace ``model``; must not be None when
         ``dependency_aware`` is set.
     **algo_kwargs
         Extra keyword arguments forwarded to the parent constructor.

     Raises
     ------
     AssertionError
         If ``dependency_aware`` is set while ``dummy_input`` is None.
     """
     super().__init__(model,
                      config_list,
                      pruning_algorithm=pruning_algorithm,
                      optimizer=optimizer,
                      **algo_kwargs)
     self.dependency_aware = dependency_aware
     # The masker carries its own copy of the dependency-aware switch.
     self.masker.dependency_aware = dependency_aware
     self.dummy_input = dummy_input
     if not self.dependency_aware:
         return

     assert self.dummy_input is not None, \
         "When dependency_aware is set, the dummy_input should not be None"

     # Tracing has to see the bare modules, so the wrappers are removed
     # while the TorchModuleGraph is built and re-applied right after.
     self._unwrap_model()
     self.graph = TorchModuleGraph(model, dummy_input)
     self._wrap_model()

     channel_analysis = ChannelDependency(traced_model=self.graph.trace)
     group_analysis = GroupDependency(traced_model=self.graph.trace)
     # Index the channel-dependency sets by member: every name that appears
     # in a set maps to that whole set.
     self.channel_depen = {
         member: dep_set
         for dep_set in channel_analysis.dependency_sets
         for member in dep_set
     }
     self.group_depen = group_analysis.dependency_sets
コード例 #2
0
 def __init__(self, model=None, dummy_input=None, traced_model=None):
     """
     Build the graph for the model, then compute its dependencies.

     Either ``traced_model`` must be provided, or both ``model`` and
     ``dummy_input`` must be provided so the model can be traced.

     Parameters
     ----------
     model : optional
         The model to trace when no ``traced_model`` is supplied.
     dummy_input : optional
         Example input used to trace ``model``.
     traced_model : optional
         An already traced model; passed to ``TorchModuleGraph`` together
         with ``model`` and ``dummy_input``.

     Raises
     ------
     AssertionError
         If ``traced_model`` is None and ``model`` or ``dummy_input`` is None.
     """
     # A pre-traced model makes model/dummy_input optional; without one,
     # both are required to trace the model.
     assert traced_model is not None or (
         model is not None and dummy_input is not None)
     self.graph = TorchModuleGraph(model, dummy_input, traced_model)
     self.dependency = {}
     self.build_dependency()
コード例 #3
0
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        This class analyzes the channel dependencies between the conv
        layers in a model.

        Either ``traced_model`` must be provided, or both ``model`` and
        ``dummy_input`` must be provided so the model can be traced.

        Parameters
        ----------
        model : torch.nn.Module
            The model to be analyzed.
        dummy_input : torch.Tensor
            The example input data used to trace the network architecture.
        traced_model
            An already traced model (e.g. the result of ``torch.jit.trace``).
            If the traced graph of the target model is already available,
            the model does not need to be traced again.

        Raises
        ------
        AssertionError
            If ``traced_model`` is None and ``model`` or ``dummy_input``
            is None.
        """
        # check if the input is legal
        if traced_model is None:
            # user should provide model & dummy_input to trace the model or a already traced model
            assert model is not None and dummy_input is not None
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        self.dependency = dict()
        self.build_channel_dependency()
コード例 #4
0
    def test_module_reuse(self):
        """
        A model that applies one ReLU module three times must still trace to
        a single top-down chain: every cpp node has at most one successor.
        """
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.liner1 = nn.Linear(10, 10)
                self.relu = nn.ReLU(inplace=True)
                self.liner2 = nn.Linear(10, 20)
                self.liner3 = nn.Linear(20, 10)

            def forward(self, x):
                # the same relu module follows every linear layer
                for linear in (self.liner1, self.liner2, self.liner3):
                    x = self.relu(linear(x))
                return x

        sample = torch.rand(10, 10)
        traced = torch.jit.trace(MyModule(), sample)
        modulegraph = TorchModuleGraph(traced_model=traced)
        # Because the relu module is reused, three cpp nodes correspond to the
        # same module. Walking down from the graph input, each cpp node
        # (including every one of those relu nodes) has exactly one successor.
        for name, nodeio in modulegraph.nodes_py.nodes_io.items():
            if nodeio.input_or_output != 'input':
                continue
            # the graph is a single top-down path, so it has one start node
            start_nodes = modulegraph.input_to_node[name]
            assert len(start_nodes) == 1
            node = start_nodes[0].unique_name
            successors = modulegraph.find_successors(node)
            while successors:
                assert len(successors) == 1
                node = successors[0]
                successors = modulegraph.find_successors(node)
コード例 #5
0
    def test_module_unpack(self):
        """
        Test the tuple/list unpack function of TorchModuleGraph.

        The models below come from issue 2756
        (https://github.com/microsoft/nni/issues/2756): MyModule performs two
        successive tuple unpack operations between B and C.
        """
        class CBR(nn.Module):
            """conv -> batchnorm -> relu"""
            def __init__(self, i, o):
                super().__init__()
                self.conv1 = nn.Conv2d(i, o, kernel_size=1)
                self.bn1 = nn.BatchNorm2d(o)
                self.act1 = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                return self.act1(x)

        class A(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = CBR(3, 6)
                self.conv2 = CBR(6, 8)
                self.conv3 = CBR(6, 12)

            def forward(self, x):
                shared = self.conv1(x)
                return (self.conv2(shared), self.conv3(shared))

        class B1(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = CBR(12, 32)
                self.conv2 = CBR(32, 32)
                self.conv3 = CBR(32, 32)

            def forward(self, x):
                out1 = self.conv1(x)
                out2 = self.conv2(out1)
                out3 = self.conv3(out2)
                return (out1, out2, out3)

        class B(nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B1()

            def forward(self, x):
                # only the last element of A's output feeds B1
                return self.b(x[-1])

        class C(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = CBR(8, 32)
                self.conv2 = CBR(12, 32)
                self.conv3 = CBR(32, 32)
                self.conv4 = CBR(32, 32)
                self.conv5 = CBR(32, 32)

            def forward(self, x):
                return (
                    self.conv1(x[0]),
                    self.conv2(x[1]),
                    self.conv3(x[2]),
                    self.conv4(x[3]),
                    self.conv5(x[4]),
                )

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()
                self.b = B()
                self.c = C()

            def forward(self, x):
                x_a = self.a(x)
                x_b = self.b(x_a)
                # tuple concatenation: A's 2 outputs followed by B's 3 outputs
                return self.c(x_a + x_b)

        dummy_input = torch.rand(1, 3, 28, 28)
        model = MyModule()
        graph = TorchModuleGraph(model, dummy_input)
        graph.unpack_manually()
        # After manual unpacking no TupleUnpack node should remain, so no
        # function node may take an input produced by one.
        for node in graph.nodes_py.nodes_op:
            for _input in node.inputs:
                if _input not in graph.output_to_node:
                    continue
                producer = graph.output_to_node[_input]
                assert producer.op_type != TUPLE_UNPACK_KIND