コード例 #1
ファイル: one_shot.py プロジェクト: zhuqingling/nni
 # Excerpt of an `__init__` from one_shot.py (nni) that configures
 # dependency-aware mode: when enabled, the model is traced and its channel
 # and group dependencies are computed.
 # NOTE(review): this excerpt is incomplete -- the parameter list of the
 # signature, the arguments of ChannelDependency(...) and the closing brace of
 # the dict comprehension below are missing, so it is not valid Python as
 # shown. Restore them from the upstream source before changing any logic.
 def __init__(self,
     # Remember whether dependencies between layers should be respected.
     self.dependency_aware = dependency_aware
     # set the dependency-aware switch for the masker
     # (self.masker must already exist here; it is created outside this excerpt)
     self.masker.dependency_aware = dependency_aware
     self.dummy_input = dummy_input
     if self.dependency_aware:
         # Tracing the model needs an example input, so dependency-aware mode
         # refuses to proceed without one.
         # NOTE(review): `assert` is stripped under `python -O`; raise an
         # exception instead if this check must always run.
         errmsg = "When dependency_aware is set, the dummy_input should not be None"
         assert self.dummy_input is not None, errmsg
         # Get the TorchModuleGraph of the target model
         # to trace the model, we need to unwrap the wrappers
         # NOTE(review): no unwrapping is visible here -- `model` is passed to
         # TorchModuleGraph as-is; confirm it is already the unwrapped module.
         self.graph = TorchModuleGraph(model, dummy_input)
         # GroupDependency reuses the trace held by self.graph instead of
         # tracing again; the arguments of ChannelDependency(...) are cut off
         # in this excerpt.
         self.channel_depen = ChannelDependency(
         self.group_depen = GroupDependency(traced_model=self.graph.trace)
         self.channel_depen = self.channel_depen.dependency_sets
         # Invert the collection of dependency sets into a lookup table: each
         # layer name maps to the dependency set that contains it (so the set
         # always includes the name itself).
         self.channel_depen = {
             name: sets
             for sets in self.channel_depen for name in sets
         self.group_depen = self.group_depen.dependency_sets
コード例 #2
 # Build the dependency-analysis graph, either by tracing `model` with
 # `dummy_input` or from an already traced model (`traced_model`).
 # NOTE(review): the bare line below is docstring text whose surrounding
 # triple quotes are missing in this excerpt; as shown it is not valid Python.
 def __init__(self, model=None, dummy_input=None, traced_model=None):
     Build the graph for the model.
     # check if the input is legal
     if traced_model is None:
         # Without an already traced model, the user must provide both
         # `model` and `dummy_input` so that the model can be traced.
         # NOTE(review): `assert` is stripped under `python -O`.
         assert model is not None and dummy_input is not None
     # All three arguments are forwarded; TorchModuleGraph presumably uses
     # traced_model when given and traces model otherwise -- confirm.
     self.graph = TorchModuleGraph(model, dummy_input, traced_model)
     # Starts empty; presumably filled by a later analysis step (not visible here).
     self.dependency = dict()
コード例 #3
    # Set up channel-dependency analysis between the conv layers of a model.
    # NOTE(review): the docstring below lost its triple quotes in this excerpt
    # (as shown it is not valid Python). It also documents a parameter named
    # `data`, but the actual parameter is `dummy_input` -- fix the docstring
    # when the quotes are restored.
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        This model analyze the channel dependencis between the conv
        layers in a model.

        model : torch.nn.Module
            The model to be analyzed.
        data : torch.Tensor
            The example input data to trace the network architecture.
        traced_model : torch._C.Graph
            if we alreay has the traced graph of the target model, we donnot
            need to trace the model again.
        # check if the input is legal
        if traced_model is None:
            # Without an already traced model, the user must provide both
            # `model` and `dummy_input` so that the model can be traced.
            # NOTE(review): `assert` is stripped under `python -O`.
            assert model is not None and dummy_input is not None
        # All three arguments are forwarded; TorchModuleGraph presumably uses
        # traced_model when given and traces model otherwise -- confirm.
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        # Starts empty; presumably filled by the dependency analysis later on.
        self.dependency = dict()
コード例 #4
    def test_module_reuse(self):
        """Check TorchModuleGraph traversal when one module is called repeatedly.

        ``MyModule`` applies the same ``nn.ReLU`` instance after each of its
        three linear layers, so the traced graph holds three cpp nodes that
        correspond to one module. Starting from the graph input, every node
        reached while walking down the graph must have exactly one successor.
        """
        class MyModule(nn.Module):
            def __init__(self):
                # nn.Module.__init__ must run before submodules are assigned;
                # otherwise nn.Module.__setattr__ raises AttributeError.
                super(MyModule, self).__init__()
                self.liner1 = nn.Linear(10, 10)
                self.relu = nn.ReLU(inplace=True)
                self.liner2 = nn.Linear(10, 20)
                self.liner3 = nn.Linear(20, 10)

            def forward(self, x):
                x = self.liner1(x)
                x = self.relu(x)
                x = self.liner2(x)
                x = self.relu(x)
                x = self.liner3(x)
                x = self.relu(x)
                return x

        data = torch.rand(10, 10)
        net = MyModule()
        traced = torch.jit.trace(net, data)
        modulegraph = TorchModuleGraph(traced_model=traced)
        # Traverse the TorchModuleGraph. Because the relu module is reused,
        # there are three cpp_nodes corresponding to the same module. While
        # traversing the graph, each cpp_node (including the ones that
        # correspond to the same relu module) should have exactly one successor.
        for name, nodeio in modulegraph.nodes_py.nodes_io.items():
            if nodeio.input_or_output != 'input':
                continue
            # Find the first node of the whole graph.
            start_nodes = modulegraph.input_to_node[name]
            # The model is a single top-down path, so there is one start node.
            assert len(start_nodes) == 1
            node = start_nodes[0].unique_name
            # Query the successors once per step instead of twice.
            successors = modulegraph.find_successors(node)
            while successors:
                assert len(successors) == 1
                node = successors[0]
                successors = modulegraph.find_successors(node)
コード例 #5
    def test_module_unpack(self):
        # Regression test for the tuple/list unpacking done by
        # TorchModuleGraph: MyModule passes tuples from A through B into C,
        # and after the graph is built no op node may still take an input
        # produced by a TupleUnpack node.
        # NOTE(review): the docstring below lost its triple quotes in this
        # excerpt, and the argument lines of A.conv1 / A.conv2 (`CBR(`) are
        # missing, so this is not valid Python as shown. Restore both from the
        # upstream test before editing.
        test the tuple/list unpack function of TorchModuleGraph.
        Following models are from the issue 2756
        MyModule will have two successive tuple unpack operations
        between the B and C.
        class CBR(nn.Module):
            # Conv2d(kernel_size=1) -> BatchNorm2d -> ReLU, mapping i to o channels.
            def __init__(self, i, o):
                super(CBR, self).__init__()
                self.conv1 = nn.Conv2d(i, o, kernel_size=1)
                self.bn1 = nn.BatchNorm2d(o)
                self.act1 = nn.ReLU()

            def forward(self, x):
                return self.act1(self.bn1(self.conv1(x)))

        class A(nn.Module):
            # Returns the 2-tuple (x2, x3); both branches consume conv1's output.
            def __init__(self):
                super(A, self).__init__()
                # NOTE(review): arguments cut off in this excerpt. From the
                # visible channel counts (3-channel dummy input,
                # conv3 = CBR(6, 12) on conv1's output, C.conv1 = CBR(8, 32)
                # on conv2's output) these are presumably CBR(3, 6) and
                # CBR(6, 8) -- confirm against upstream.
                self.conv1 = CBR(
                self.conv2 = CBR(
                self.conv3 = CBR(6, 12)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                x3 = self.conv3(x1)
                return (x2, x3)

        class B1(nn.Module):
            # Three chained CBRs; returns every intermediate output as a 3-tuple.
            def __init__(self):
                super(B1, self).__init__()
                self.conv1 = CBR(12, 32)
                self.conv2 = CBR(32, 32)
                self.conv3 = CBR(32, 32)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                x3 = self.conv3(x2)
                return (x1, x2, x3)

        class B(nn.Module):
            # Wraps B1 so that indexing into A's output tuple happens inside B.
            def __init__(self):
                super(B, self).__init__()
                self.b = B1()

            def forward(self, x):
                # x is A's output tuple; only its last element (x3) is used.
                return self.b(x[-1])

        class C(nn.Module):
            # Consumes a 5-tuple, applying one CBR to each element.
            def __init__(self):
                super(C, self).__init__()
                self.conv1 = CBR(8, 32)
                self.conv2 = CBR(12, 32)
                self.conv3 = CBR(32, 32)
                self.conv4 = CBR(32, 32)
                self.conv5 = CBR(32, 32)

            def forward(self, x):
                return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
                        self.conv4(x[3]), self.conv5(x[4]))

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.a = A()
                self.b = B()
                # self.dummy = Dummy()
                self.c = C()

            def forward(self, x):
                x_a = self.a(x)
                x_b = self.b(x_a)
                # `+` concatenates A's 2-tuple and B's 3-tuple into the
                # 5-tuple that C indexes.
                xc = self.c(x_a + x_b)
                return xc

        dummy_input = torch.rand(1, 3, 28, 28)
        model = MyModule()
        graph = TorchModuleGraph(model, dummy_input)
        for node in graph.nodes_py.nodes_op:
            # The inputs of op nodes must not come from a TupleUnpack node,
            # because TorchModuleGraph is expected to have removed (unpacked)
            # all TupleUnpack nodes.
            for _input in node.inputs:
                # Only inputs that some node produces are checked;
                # `preprocessor` is the node recorded as producing `_input`.
                if _input in graph.output_to_node:
                    preprocessor = graph.output_to_node[_input]
                    assert preprocessor.op_type != TUPLE_UNPACK_KIND