def __init__(self, model, config_list, pruning_algorithm, optimizer=None,
             dependency_aware=False, dummy_input=None, **algo_kwargs):
    super().__init__(model, config_list, pruning_algorithm=pruning_algorithm,
                     optimizer=optimizer, **algo_kwargs)
    self.dependency_aware = dependency_aware
    # set the dependency-aware switch for the masker
    self.masker.dependency_aware = dependency_aware
    self.dummy_input = dummy_input
    if self.dependency_aware:
        errmsg = "When dependency_aware is set, the dummy_input should not be None"
        assert self.dummy_input is not None, errmsg
        # Get the TorchModuleGraph of the target model;
        # to trace the model, we need to unwrap the wrappers first
        self._unwrap_model()
        self.graph = TorchModuleGraph(model, dummy_input)
        self._wrap_model()
        self.channel_depen = ChannelDependency(traced_model=self.graph.trace)
        self.group_depen = GroupDependency(traced_model=self.graph.trace)
        # flatten the channel dependency sets into a lookup table:
        # layer name -> the dependency set that contains it
        self.channel_depen = self.channel_depen.dependency_sets
        self.channel_depen = {
            name: sets for sets in self.channel_depen for name in sets
        }
        self.group_depen = self.group_depen.dependency_sets
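A minimal usage sketch of the dependency-aware mode, for illustration only: `L1FilterPruner` is one concrete pruner built on this `__init__`, and the import path, config list, and input shape below are assumptions that may vary across NNI versions.

import torch
from torchvision.models import resnet18
from nni.compression.torch import L1FilterPruner  # path may differ by NNI version

model = resnet18()
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
# dependency_aware=True requires a dummy_input so the model can be traced
dummy_input = torch.rand(1, 3, 224, 224)
pruner = L1FilterPruner(model, config_list,
                        dependency_aware=True, dummy_input=dummy_input)
pruner.compress()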
def __init__(self, model=None, dummy_input=None, traced_model=None):
    """
    Build the dependency graph for the model.
    """
    # check if the input is legal
    if traced_model is None:
        # the user should provide model & dummy_input to trace
        # the model, or an already traced model
        assert model is not None and dummy_input is not None
    self.graph = TorchModuleGraph(model, dummy_input, traced_model)
    self.dependency = dict()
    self.build_dependency()
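A sketch of the two legal construction modes, using `ChannelDependency` (defined below) as the concrete subclass; the model and input shape are assumptions, and whether `traced_model` takes the jit-traced module or its graph depends on how `TorchModuleGraph` unwraps it (the `test_module_reuse` test below passes the traced module).

import torch
from torchvision.models import resnet18

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)

# mode 1: pass model + dummy_input and let the class trace the model
dep = ChannelDependency(model=model, dummy_input=dummy_input)

# mode 2: reuse an existing trace so the model is not traced twice
traced = torch.jit.trace(model, dummy_input)
dep = ChannelDependency(traced_model=traced)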
def test_module_reuse(self):
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(10, 10)
            self.relu = nn.ReLU(inplace=True)
            self.linear2 = nn.Linear(10, 20)
            self.linear3 = nn.Linear(20, 10)

        def forward(self, x):
            x = self.linear1(x)
            x = self.relu(x)
            x = self.linear2(x)
            x = self.relu(x)
            x = self.linear3(x)
            x = self.relu(x)
            return x

    data = torch.rand(10, 10)
    net = MyModule()
    traced = torch.jit.trace(net, data)
    modulegraph = TorchModuleGraph(traced_model=traced)
    # Traverse the TorchModuleGraph. Due to the reuse of the relu module,
    # there will be three cpp_nodes corresponding to the same module.
    # During the traversal, each cpp_node (including the cpp_nodes that
    # correspond to the same relu module) should have exactly one successor.
    for name, nodeio in modulegraph.nodes_py.nodes_io.items():
        if nodeio.input_or_output == 'input':
            # find the first node of the whole graph
            start_nodes = modulegraph.input_to_node[name]
            # we have only one single path top-down
            assert len(start_nodes) == 1
            node = start_nodes[0].unique_name
            while modulegraph.find_successors(node):
                nodes = modulegraph.find_successors(node)
                assert len(nodes) == 1
                node = nodes[0]
class ChannelDependency:
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        This class analyzes the channel dependencies between the conv
        layers in a model.

        Parameters
        ----------
        model : torch.nn.Module
            The model to be analyzed.
        dummy_input : torch.Tensor
            The example input data used to trace the network architecture.
        traced_model : torch._C.Graph
            If we already have the traced graph of the target model,
            we do not need to trace the model again.
        """
        # check if the input is legal
        if traced_model is None:
            # the user should provide model & dummy_input to trace
            # the model, or an already traced model
            assert model is not None and dummy_input is not None
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        self.dependency = dict()
        self.build_channel_dependency()

    def _get_parent_layers(self, node):
        """
        Find the nearest parent conv layers for the target node.

        Parameters
        ----------
        node : torch._C.Node
            The target node.

        Returns
        -------
        parent_layers : list
            The nearest parent conv/linear layers for the target node.
        """
        parent_layers = []
        queue = []
        queue.append(node)
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
                # stop at the first conv/linear layer on this path
                parent_layers.append(curnode.name)
                continue
            parents = self.graph.find_predecessors(curnode.unique_name)
            parents = [self.graph.name_to_node[name] for name in parents]
            for parent in parents:
                queue.append(parent)
        return parent_layers

    def build_channel_dependency(self):
        """
        Build the channel dependency for the conv layers in the model.
        """
        for node in self.graph.nodes_py.nodes_op:
            parent_layers = []
            # find the nodes that contain aten::add
            # or aten::cat operations
            if node.op_type in ADD_TYPES:
                parent_layers = self._get_parent_layers(node)
            elif node.op_type == CAT_TYPE:
                # To determine if this cat operation will introduce channel
                # dependency, we need the specific input parameters of the
                # cat operation. To get them, we need to traverse all the
                # cpp_nodes included by this NodePyGroup, because
                # TorchModuleGraph merges the important nodes and the
                # adjacent unimportant nodes (nodes starting with prim::attr,
                # for example) into a NodePyGroup.
                cat_dim = None
                for cnode in node.node_cpps:
                    if cnode.kind() == CAT_TYPE:
                        cat_dim = list(cnode.inputs())[1].toIValue()
                        break
                # cat along the channel dimension (dim 1) does not require
                # its inputs to have the same channel count, so only other
                # dims introduce a dependency
                if cat_dim != 1:
                    parent_layers = self._get_parent_layers(node)
            dependency_set = set(parent_layers)
            # merge the dependencies
            for parent in parent_layers:
                if parent in self.dependency:
                    dependency_set.update(self.dependency[parent])
            # save the dependencies
            for _node in dependency_set:
                self.dependency[_node] = dependency_set

    def export(self, filepath):
        """
        Export the channel dependencies as a csv file. The layers on the
        same line have output channel dependencies with each other. For
        example, layer1.1.conv2, conv1, and layer1.0.conv2 have output
        channel dependencies with each other, which means the output
        channel (filter) counts of these three layers should be the same,
        otherwise the model may have a shape conflict.

        Output example:
        Dependency Set,Convolutional Layers
        Set 1,layer1.1.conv2,layer1.0.conv2,conv1
        Set 2,layer1.0.conv1
        Set 3,layer1.1.conv1
        """
        header = ['Dependency Set', 'Convolutional Layers']
        setid = 0
        visited = set()
        with open(filepath, 'w') as csvf:
            csv_w = csv.writer(csvf, delimiter=',')
            csv_w.writerow(header)
            for node in self.graph.nodes_py.nodes_op:
                if node.op_type != 'Conv2d' or node in visited:
                    continue
                setid += 1
                row = ['Set %d' % setid]
                if node.name not in self.dependency:
                    # this layer has no dependency on any other layer
                    visited.add(node)
                    row.append(node.name)
                else:
                    for other in self.dependency[node.name]:
                        visited.add(self.graph.name_to_node[other])
                        row.append(other)
                csv_w.writerow(row)

    @property
    def dependency_sets(self):
        """
        Get the list of the dependency sets.

        Returns
        -------
        dependency_sets : list
            List of the dependency sets. For example,
            [set(['conv1', 'conv2']), set(['conv3', 'conv4'])]
        """
        d_sets = []
        visited = set()
        for node in self.graph.nodes_py.nodes_op:
            if node.op_type != 'Conv2d' or node in visited:
                continue
            tmp_set = set()
            if node.name not in self.dependency:
                visited.add(node)
                tmp_set.add(node.name)
            else:
                for other in self.dependency[node.name]:
                    visited.add(self.graph.name_to_node[other])
                    tmp_set.add(other)
            d_sets.append(tmp_set)
        return d_sets
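A short usage sketch for ChannelDependency; resnet18 (whose residual adds create channel dependencies) and the output path are illustrative assumptions.

import torch
from torchvision.models import resnet18

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)
channel_depen = ChannelDependency(model, dummy_input)
# layers inside one set must keep the same number of output channels
for dep_set in channel_depen.dependency_sets:
    print(dep_set)
channel_depen.export('channel_dependency.csv')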
def test_module_unpack(self):
    """
    Test the tuple/list unpack function of TorchModuleGraph.
    The following models are from issue 2756
    https://github.com/microsoft/nni/issues/2756.
    MyModule will have two successive tuple unpack operations
    between B and C.
    """
    class CBR(nn.Module):
        def __init__(self, i, o):
            super(CBR, self).__init__()
            self.conv1 = nn.Conv2d(i, o, kernel_size=1)
            self.bn1 = nn.BatchNorm2d(o)
            self.act1 = nn.ReLU()

        def forward(self, x):
            return self.act1(self.bn1(self.conv1(x)))

    class A(nn.Module):
        def __init__(self):
            super(A, self).__init__()
            self.conv1 = CBR(3, 6)
            self.conv2 = CBR(6, 8)
            self.conv3 = CBR(6, 12)

        def forward(self, x):
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            x3 = self.conv3(x1)
            return (x2, x3)

    class B1(nn.Module):
        def __init__(self):
            super(B1, self).__init__()
            self.conv1 = CBR(12, 32)
            self.conv2 = CBR(32, 32)
            self.conv3 = CBR(32, 32)

        def forward(self, x):
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            x3 = self.conv3(x2)
            return (x1, x2, x3)

    class B(nn.Module):
        def __init__(self):
            super(B, self).__init__()
            self.b = B1()

        def forward(self, x):
            return self.b(x[-1])

    class C(nn.Module):
        def __init__(self):
            super(C, self).__init__()
            self.conv1 = CBR(8, 32)
            self.conv2 = CBR(12, 32)
            self.conv3 = CBR(32, 32)
            self.conv4 = CBR(32, 32)
            self.conv5 = CBR(32, 32)

        def forward(self, x):
            return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
                    self.conv4(x[3]), self.conv5(x[4]))

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.a = A()
            self.b = B()
            self.c = C()

        def forward(self, x):
            x_a = self.a(x)
            x_b = self.b(x_a)
            xc = self.c(x_a + x_b)
            return xc

    dummy_input = torch.rand(1, 3, 28, 28)
    model = MyModule()
    graph = TorchModuleGraph(model, dummy_input)
    graph.unpack_manually()
    for node in graph.nodes_py.nodes_op:
        # The inputs of the function nodes should not come from a
        # TupleUnpack node, because all the TupleUnpack nodes have
        # been removed (unpacked) manually.
        for _input in node.inputs:
            if _input in graph.output_to_node:
                predecessor = graph.output_to_node[_input]
                assert predecessor.op_type != TUPLE_UNPACK_KIND