def __init__(self, model, config_list, pruning_algorithm, optimizer=None,
             dependency_aware=False, dummy_input=None, **algo_kwargs):
    """
    Set up the pruner and, if requested, the layer dependency information.

    Parameters
    ----------
    model
        The target model. It is forwarded to the parent constructor and,
        in dependency-aware mode, traced together with ``dummy_input``.
    config_list
        Forwarded unchanged to the parent constructor.
    pruning_algorithm
        Forwarded to the parent constructor as ``pruning_algorithm``.
    optimizer
        Forwarded to the parent constructor as ``optimizer``.
    dependency_aware
        Switch for the dependency-aware mode. The value is stored on both
        this object and its masker.
    dummy_input
        Example input used to trace the model. Must not be None when
        ``dependency_aware`` is set.
    **algo_kwargs
        Extra keyword arguments forwarded to the parent constructor.

    Raises
    ------
    AssertionError
        If ``dependency_aware`` is set and ``dummy_input`` is None.
    """
    super().__init__(
        model,
        config_list,
        pruning_algorithm=pruning_algorithm,
        optimizer=optimizer,
        **algo_kwargs
    )
    self.dependency_aware = dependency_aware
    # Propagate the dependency-aware switch to the masker as well.
    self.masker.dependency_aware = dependency_aware
    self.dummy_input = dummy_input

    # Nothing else to prepare unless the dependency-aware mode is enabled.
    if not self.dependency_aware:
        return

    errmsg = "When dependency_aware is set, the dummy_input should not be None"
    assert self.dummy_input is not None, errmsg

    # The wrappers have to be removed while the model is traced, and are
    # put back right after the graph has been built.
    self._unwrap_model()
    self.graph = TorchModuleGraph(model, dummy_input)
    self._wrap_model()

    # Both analyses reuse the trace held by the graph built above.
    channel_analysis = ChannelDependency(traced_model=self.graph.trace)
    group_analysis = GroupDependency(traced_model=self.graph.trace)

    # Index the channel dependency sets by member name, so that every name
    # maps to the set it belongs to.
    set_by_name = {}
    for dep_set in channel_analysis.dependency_sets:
        for layer_name in dep_set:
            set_by_name[layer_name] = dep_set
    self.channel_depen = set_by_name

    self.group_depen = group_analysis.dependency_sets
def __init__(self, model=None, dummy_input=None, traced_model=None):
    """
    Build the graph for the model and run the dependency analysis.

    Either an already traced model, or both a model and a dummy input to
    trace it with, must be supplied.

    Parameters
    ----------
    model
        The model to trace; required when ``traced_model`` is None.
    dummy_input
        The input used to trace ``model``; required when ``traced_model``
        is None.
    traced_model
        An already traced model. When it is None, ``model`` and
        ``dummy_input`` must both be given.

    Raises
    ------
    AssertionError
        If ``traced_model`` is None and ``model`` or ``dummy_input`` is
        None as well.
    """
    # Without a ready-made trace, both tracing ingredients are mandatory.
    must_trace = traced_model is None
    if must_trace:
        assert not (model is None or dummy_input is None)

    self.graph = TorchModuleGraph(model, dummy_input, traced_model)
    self.dependency = {}
    self.build_dependency()
def __init__(self, model=None, dummy_input=None, traced_model=None):
    """
    Analyze the channel dependencies between the conv layers in a model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be analyzed.
    dummy_input : torch.Tensor
        The example input data used to trace the network architecture.
    traced_model : torch._C.Graph
        If the traced graph of the target model is already available,
        pass it here so that the model does not need to be traced again.

    Raises
    ------
    AssertionError
        If ``traced_model`` is None and ``model`` or ``dummy_input`` is
        None as well.
    """
    # Legal input is either a ready-made trace, or the (model, dummy_input)
    # pair needed to produce one.
    can_trace = model is not None and dummy_input is not None
    assert traced_model is not None or can_trace

    self.graph = TorchModuleGraph(model, dummy_input, traced_model)
    self.dependency = {}
    self.build_channel_dependency()
def test_module_reuse(self):
    """Check graph traversal when one module instance is called repeatedly."""

    class MyModule(nn.Module):
        """Three linear layers, each followed by the same shared ReLU."""

        def __init__(self):
            super().__init__()
            self.liner1 = nn.Linear(10, 10)
            self.relu = nn.ReLU(inplace=True)
            self.liner2 = nn.Linear(10, 20)
            self.liner3 = nn.Linear(20, 10)

        def forward(self, x):
            x = self.relu(self.liner1(x))
            x = self.relu(self.liner2(x))
            x = self.relu(self.liner3(x))
            return x

    data = torch.rand(10, 10)
    net = MyModule()
    traced = torch.jit.trace(net, data)
    modulegraph = TorchModuleGraph(traced_model=traced)

    # Traverse the TorchModuleGraph. Because the relu module is reused,
    # there will be three cpp_nodes corresponding to the same module.
    # While traversing the graph, every cpp_node (including the ones that
    # correspond to the shared relu module) should have exactly one
    # successor.
    for name, nodeio in modulegraph.nodes_py.nodes_io.items():
        if nodeio.input_or_output != 'input':
            continue

        # Locate the first node of the whole graph.
        entry_nodes = modulegraph.input_to_node[name]
        # The model is a single top-down path.
        assert len(entry_nodes) == 1

        current = entry_nodes[0].unique_name
        while modulegraph.find_successors(current):
            successors = modulegraph.find_successors(current)
            assert len(successors) == 1
            current = successors[0]
def test_module_unpack(self):
    """
    Test the tuple/list unpack function of TorchModuleGraph.

    The models below are taken from issue 2756,
    https://github.com/microsoft/nni/issues/2756.
    MyModule will have two successive tuple unpack operations
    between B and C.
    """

    class CBR(nn.Module):
        """A 1x1 convolution followed by batch norm and ReLU."""

        def __init__(self, i, o):
            super().__init__()
            self.conv1 = nn.Conv2d(i, o, kernel_size=1)
            self.bn1 = nn.BatchNorm2d(o)
            self.act1 = nn.ReLU()

        def forward(self, x):
            convolved = self.conv1(x)
            normalized = self.bn1(convolved)
            return self.act1(normalized)

    class A(nn.Module):
        """Returns a tuple of two tensors that share a common stem."""

        def __init__(self):
            super().__init__()
            self.conv1 = CBR(3, 6)
            self.conv2 = CBR(6, 8)
            self.conv3 = CBR(6, 12)

        def forward(self, x):
            stem = self.conv1(x)
            return (self.conv2(stem), self.conv3(stem))

    class B1(nn.Module):
        """Returns every intermediate result of a three-stage chain."""

        def __init__(self):
            super().__init__()
            self.conv1 = CBR(12, 32)
            self.conv2 = CBR(32, 32)
            self.conv3 = CBR(32, 32)

        def forward(self, x):
            stage1 = self.conv1(x)
            stage2 = self.conv2(stage1)
            stage3 = self.conv3(stage2)
            return (stage1, stage2, stage3)

    class B(nn.Module):
        """Feeds the last element of its tuple input into B1."""

        def __init__(self):
            super().__init__()
            self.b = B1()

        def forward(self, x):
            last = x[-1]
            return self.b(last)

    class C(nn.Module):
        """Applies one CBR block to each of the five input elements."""

        def __init__(self):
            super().__init__()
            self.conv1 = CBR(8, 32)
            self.conv2 = CBR(12, 32)
            self.conv3 = CBR(32, 32)
            self.conv4 = CBR(32, 32)
            self.conv5 = CBR(32, 32)

        def forward(self, x):
            out0 = self.conv1(x[0])
            out1 = self.conv2(x[1])
            out2 = self.conv3(x[2])
            out3 = self.conv4(x[3])
            out4 = self.conv5(x[4])
            return (out0, out1, out2, out3, out4)

    class MyModule(nn.Module):
        """Chains A, B and C; C receives the outputs of A and B together."""

        def __init__(self):
            super().__init__()
            self.a = A()
            self.b = B()
            self.c = C()

        def forward(self, x):
            x_a = self.a(x)
            x_b = self.b(x_a)
            # Tuple concatenation: the two outputs of A followed by the
            # three outputs of B.
            return self.c(x_a + x_b)

    dummy_input = torch.rand(1, 3, 28, 28)
    model = MyModule()
    graph = TorchModuleGraph(model, dummy_input)
    graph.unpack_manually()

    # After the manual unpacking, no op node may take its input from a
    # TupleUnpack node, since all of those nodes should have been
    # removed (unpacked).
    for op_node in graph.nodes_py.nodes_op:
        for input_name in op_node.inputs:
            if input_name not in graph.output_to_node:
                continue
            producer = graph.output_to_node[input_name]
            assert producer.op_type != TUPLE_UNPACK_KIND