def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level',
             dependency_aware=False, dummy_input=None, **algo_kwargs):
    super().__init__(model, config_list=config_list, optimizer=optimizer)

    self.dependency_aware = dependency_aware
    self.dummy_input = dummy_input

    if self.dependency_aware:
        if not self._supported_dependency_aware():
            raise ValueError('This pruner does not support the dependency-aware mode!')
        errmsg = 'When dependency_aware is set, the dummy_input should not be None'
        assert self.dummy_input is not None, errmsg
        # Get the TorchModuleGraph of the target model;
        # to trace the model, we need to unwrap the wrappers first.
        self._unwrap_model()
        self.graph = TorchModuleGraph(model, dummy_input)
        self._wrap_model()
        self.channel_depen = ChannelDependency(traced_model=self.graph.trace)
        self.group_depen = GroupDependency(traced_model=self.graph.trace)
        self.channel_depen = self.channel_depen.dependency_sets
        # Map each layer name to the dependency set it belongs to.
        self.channel_depen = {
            name: sets for sets in self.channel_depen for name in sets}
        self.group_depen = self.group_depen.dependency_sets

    self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs)
    # set the dependency-aware switch for the masker
    self.masker.dependency_aware = dependency_aware
    self.set_wrappers_attribute("if_calculated", False)
def __init__(self, model, config_list, pruning_algorithm, optimizer=None,
             dependency_aware=False, dummy_input=None, **algo_kwargs):
    super().__init__(model, config_list, pruning_algorithm=pruning_algorithm,
                     optimizer=optimizer, **algo_kwargs)

    self.dependency_aware = dependency_aware
    # set the dependency-aware switch for the masker
    self.masker.dependency_aware = dependency_aware
    self.dummy_input = dummy_input

    if self.dependency_aware:
        errmsg = 'When dependency_aware is set, the dummy_input should not be None'
        assert self.dummy_input is not None, errmsg
        # Get the TorchModuleGraph of the target model;
        # to trace the model, we need to unwrap the wrappers first.
        self._unwrap_model()
        self.graph = TorchModuleGraph(model, dummy_input)
        self._wrap_model()
        self.channel_depen = ChannelDependency(traced_model=self.graph.trace)
        self.group_depen = GroupDependency(traced_model=self.graph.trace)
        self.channel_depen = self.channel_depen.dependency_sets
        self.channel_depen = {
            name: sets for sets in self.channel_depen for name in sets}
        self.group_depen = self.group_depen.dependency_sets
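# A minimal sketch (not part of the original code) of what the dict
# comprehension above produces: every layer name inside a dependency set
# is mapped to the full set it belongs to, so the masker can look up all
# layers that must be pruned together in O(1). The layer names here are
# hypothetical.
channel_depen_sets = [{'conv1', 'conv2'}, {'conv3'}]
channel_depen = {name: sets for sets in channel_depen_sets for name in sets}
# channel_depen == {'conv1': {'conv1', 'conv2'},
#                   'conv2': {'conv1', 'conv2'},
#                   'conv3': {'conv3'}}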
def test_module_reuse(self):
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.liner1 = nn.Linear(10, 10)
            self.relu = nn.ReLU(inplace=True)
            self.liner2 = nn.Linear(10, 20)
            self.liner3 = nn.Linear(20, 10)

        def forward(self, x):
            x = self.liner1(x)
            x = self.relu(x)
            x = self.liner2(x)
            x = self.relu(x)
            x = self.liner3(x)
            x = self.relu(x)
            return x

    data = torch.rand(10, 10)
    net = MyModule()
    traced = torch.jit.trace(net, data)
    modulegraph = TorchModuleGraph(traced_model=traced)
    # Traverse the TorchModuleGraph. Due to the reuse of the relu module,
    # there will be three cpp_nodes corresponding to the same module.
    # While traversing the graph, each cpp_node (including the cpp_nodes
    # that correspond to the same relu module) should have exactly one
    # successor.
    for name, nodeio in modulegraph.nodes_py.nodes_io.items():
        if nodeio.input_or_output == 'input':
            # Find the first node of the whole graph
            start_nodes = modulegraph.input_to_node[name]
            # There is only a single top-down path through the graph
            assert len(start_nodes) == 1
            node = start_nodes[0].unique_name
            while modulegraph.find_successors(node):
                nodes = modulegraph.find_successors(node)
                assert len(nodes) == 1
                node = nodes[0]
def __init__(self, model=None, dummy_input=None, traced_model=None):
    """
    Build the graph for the model.
    """
    from nni.common.graph_utils import TorchModuleGraph

    # check that the input is legal
    if traced_model is None:
        # The user should provide either model & dummy_input to trace
        # the model, or an already traced model.
        assert model is not None and dummy_input is not None
    self.graph = TorchModuleGraph(model, dummy_input, traced_model)
    self.dependency = dict()
    self.build_dependency()
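# A hedged usage sketch for the constructor above: a dependency subclass
# (ChannelDependency is used here) can be built either from a model plus a
# dummy input, or from an already traced graph. The model, input shape,
# and import path are illustrative; the import path may differ across NNI
# versions.
import torch
from torchvision.models import resnet18
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)

# Option 1: let the constructor trace the model itself.
channel_depen = ChannelDependency(model=model, dummy_input=dummy_input)

# Option 2: reuse an existing trace (see TorchModuleGraph elsewhere in
# this file), avoiding a second jit.trace pass:
# channel_depen = ChannelDependency(traced_model=graph.trace)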
def group_weight_names_by_graph(self):
    """
    Populate self.attention_name_groups by running inference on the module graph.
    Currently, each group inferred by AttentionWeightDependency is limited to a
    set of four weights, with the first three corresponding to Q_proj, K_proj,
    V_proj (in any order) and the last one being output_proj.
    """
    try:
        module_graph = TorchModuleGraph(self.bound_model, self.dummy_input)
        dependency_tracer = AttentionWeightDependency(
            traced_model=module_graph.trace)
        self.attention_name_groups = dependency_tracer.dependency_sets
        self.group_weights_by_name()
    except Exception as e:
        raise RuntimeError(
            'Graph trace failed: please check dummy_input, or specify attention_name_groups.\n'
            'Exception message: ' + str(e))
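# A hedged illustration (not original code) of the structure populated
# above. The module names are hypothetical, BERT-style names; the real
# names depend on the bound model.
attention_name_groups = [
    ['encoder.layer.0.attention.self.query',     # Q_proj
     'encoder.layer.0.attention.self.key',       # K_proj
     'encoder.layer.0.attention.self.value',     # V_proj
     'encoder.layer.0.attention.output.dense'],  # output_proj
]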
def _get_dependency(self, dummy_input: Any):
    # Get the channel dependency and group dependency.
    # channel dependency format: [[module_name1, module_name2], [module_name3], ...]
    # group dependency format: {module_name: group_num}
    self.pruner._unwrap_model()
    graph = TorchModuleGraph(model=self.pruner.bound_model, dummy_input=dummy_input)
    channel_dependency = ChannelDependency(
        model=self.pruner.bound_model, dummy_input=dummy_input,
        traced_model=graph.trace).dependency_sets
    group_dependency = GroupDependency(
        model=self.pruner.bound_model, dummy_input=dummy_input,
        traced_model=graph.trace).dependency_sets
    self.pruner._wrap_model()
    return channel_dependency, group_dependency
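# A minimal, self-contained sketch of the two structures returned above,
# with hypothetical module names, showing how a caller might iterate them.
channel_dependency = [['layer1.conv1', 'layer1.downsample'], ['layer2.conv1']]
group_dependency = {'layer3.conv1': 32}

for dependency_set in channel_dependency:
    # all layers in one set must keep identical output-channel masks
    print('prune together:', dependency_set)
for module_name, group_num in group_dependency.items():
    # the pruned channels of this layer must stay balanced across its groups
    print(module_name, 'has', group_num, 'groups')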
def generate_graph(self, dummy_input: Any) -> TorchModuleGraph:
    """
    Generate a `TorchModuleGraph` instance of `self.bound_model` based on `jit.trace`.

    Parameters
    ----------
    dummy_input
        The dummy input for `jit.trace`; users should put it on the right
        device before passing it in.

    Returns
    -------
    TorchModuleGraph
        A `TorchModuleGraph` instance.
    """
    self._unwrap_model()
    graph = TorchModuleGraph(model=self.bound_model, dummy_input=dummy_input)
    self._wrap_model()
    return graph
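# A hedged usage example for generate_graph. The pruner class, config
# list, model, and input shape are illustrative, and the import path may
# differ across NNI versions.
import torch
from torchvision.models import resnet18
from nni.compression.pytorch.pruning import L1NormPruner

model = resnet18()
pruner = L1NormPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
# keep the dummy input on the same device as the model before tracing
dummy_input = torch.rand(1, 3, 224, 224)
graph = pruner.generate_graph(dummy_input)
print(len(graph.nodes_py.nodes_op))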
def test_module_unpack(self):
    """
    Test the tuple/list unpack function of TorchModuleGraph.
    The following models are from issue 2756
    https://github.com/microsoft/nni/issues/2756.
    MyModule will have two successive tuple unpack operations
    between B and C.
    """
    class CBR(nn.Module):
        def __init__(self, i, o):
            super(CBR, self).__init__()
            self.conv1 = nn.Conv2d(i, o, kernel_size=1)
            self.bn1 = nn.BatchNorm2d(o)
            self.act1 = nn.ReLU()

        def forward(self, x):
            return self.act1(self.bn1(self.conv1(x)))

    class A(nn.Module):
        def __init__(self):
            super(A, self).__init__()
            self.conv1 = CBR(3, 6)
            self.conv2 = CBR(6, 8)
            self.conv3 = CBR(6, 12)

        def forward(self, x):
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            x3 = self.conv3(x1)
            return (x2, x3)

    class B1(nn.Module):
        def __init__(self):
            super(B1, self).__init__()
            self.conv1 = CBR(12, 32)
            self.conv2 = CBR(32, 32)
            self.conv3 = CBR(32, 32)

        def forward(self, x):
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            x3 = self.conv3(x2)
            return (x1, x2, x3)

    class B(nn.Module):
        def __init__(self):
            super(B, self).__init__()
            self.b = B1()

        def forward(self, x):
            return self.b(x[-1])

    class C(nn.Module):
        def __init__(self):
            super(C, self).__init__()
            self.conv1 = CBR(8, 32)
            self.conv2 = CBR(12, 32)
            self.conv3 = CBR(32, 32)
            self.conv4 = CBR(32, 32)
            self.conv5 = CBR(32, 32)

        def forward(self, x):
            return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
                    self.conv4(x[3]), self.conv5(x[4]))

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.a = A()
            self.b = B()
            self.c = C()

        def forward(self, x):
            x_a = self.a(x)
            x_b = self.b(x_a)
            x_c = self.c(x_a + x_b)
            return x_c

    dummy_input = torch.rand(1, 3, 28, 28)
    model = MyModule()
    graph = TorchModuleGraph(model, dummy_input)
    graph.unpack_manually()
    for node in graph.nodes_py.nodes_op:
        # The inputs of the function nodes should not come from a
        # TupleUnpack node, because all the TupleUnpack nodes have
        # been removed (unpacked) manually.
        for _input in node.inputs:
            if _input in graph.output_to_node:
                predecessor = graph.output_to_node[_input]
                assert predecessor.op_type != TUPLE_UNPACK_KIND