def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
             dummy_input=None, **algo_kwargs):
    super().__init__(model, config_list=config_list, optimizer=optimizer)

    self.dependency_aware = dependency_aware
    self.dummy_input = dummy_input

    if self.dependency_aware:
        if not self._supported_dependency_aware():
            raise ValueError('This pruner does not support dependency-aware mode!')

        errmsg = "When dependency_aware is set, the dummy_input should not be None"
        assert self.dummy_input is not None, errmsg
        # Build the TorchModuleGraph of the target model.
        # To trace the model, we need to unwrap the wrappers first.
        self._unwrap_model()
        self.graph = TorchModuleGraph(model, dummy_input)
        self._wrap_model()
        self.channel_depen = ChannelDependency(
            traced_model=self.graph.trace)
        self.group_depen = GroupDependency(traced_model=self.graph.trace)
        self.channel_depen = self.channel_depen.dependency_sets
        # map each layer name to the dependency set it belongs to
        self.channel_depen = {
            name: sets for sets in self.channel_depen for name in sets}
        self.group_depen = self.group_depen.dependency_sets

    self.masker = MASKER_DICT[pruning_algorithm](
        model, self, **algo_kwargs)
    # set the dependency-aware switch for the masker
    self.masker.dependency_aware = dependency_aware
    self.set_wrappers_attribute("if_calculated", False)
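A minimal usage sketch for the constructor above, assuming NNI 2.x's one-shot pruner API (e.g. L1FilterPruner, which routes through a dependency-aware base like this one); the model and config values are illustrative:

import torch
from torchvision.models import resnet18
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

model = resnet18()
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
# dependency_aware=True requires a dummy_input so the model can be traced
pruner = L1FilterPruner(model, config_list, dependency_aware=True,
                        dummy_input=torch.rand(1, 3, 224, 224))
pruner.compress()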
Example #2
def __init__(self,
             model,
             config_list,
             pruning_algorithm,
             optimizer=None,
             dependency_aware=False,
             dummy_input=None,
             **algo_kwargs):
    super().__init__(model,
                     config_list,
                     pruning_algorithm=pruning_algorithm,
                     optimizer=optimizer,
                     **algo_kwargs)
    self.dependency_aware = dependency_aware
    # set the dependency-aware switch for the masker
    self.masker.dependency_aware = dependency_aware
    self.dummy_input = dummy_input
    if self.dependency_aware:
        errmsg = "When dependency_aware is set, the dummy_input should not be None"
        assert self.dummy_input is not None, errmsg
        # Build the TorchModuleGraph of the target model.
        # To trace the model, we need to unwrap the wrappers first.
        self._unwrap_model()
        self.graph = TorchModuleGraph(model, dummy_input)
        self._wrap_model()
        self.channel_depen = ChannelDependency(
            traced_model=self.graph.trace)
        self.group_depen = GroupDependency(traced_model=self.graph.trace)
        self.channel_depen = self.channel_depen.dependency_sets
        # map each layer name to the dependency set it belongs to
        self.channel_depen = {
            name: sets
            for sets in self.channel_depen for name in sets
        }
        self.group_depen = self.group_depen.dependency_sets
Example #3
    def test_module_reuse(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.liner1 = nn.Linear(10, 10)
                self.relu = nn.ReLU(inplace=True)
                self.liner2 = nn.Linear(10, 20)
                self.liner3 = nn.Linear(20, 10)

            def forward(self, x):
                x = self.liner1(x)
                x = self.relu(x)
                x = self.liner2(x)
                x = self.relu(x)
                x = self.liner3(x)
                x = self.relu(x)
                return x

        data = torch.rand(10, 10)
        net = MyModule()
        traced = torch.jit.trace(net, data)
        modulegraph = TorchModuleGraph(traced_model=traced)
        # Traverse the TorchModuleGraph. Due to the reuse of the relu module,
        # there will be three cpp_nodes corresponding to the same module.
        # While traversing the graph, each cpp_node should have exactly one
        # successor (including the cpp_nodes that correspond to the same
        # relu module).
        for name, nodeio in modulegraph.nodes_py.nodes_io.items():
            if nodeio.input_or_output == 'input':
                # Find the first node of the whole graph
                start_nodes = modulegraph.input_to_node[name]
                # We have only one single path top-down
                assert len(start_nodes) == 1
                node = start_nodes[0].unique_name
                while modulegraph.find_successors(node):
                    nodes = modulegraph.find_successors(node)
                    assert len(nodes) == 1
                    node = nodes[0]
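As a companion check, here is a sketch of the same idea; it assumes each call site of a reused module gets its own node in nodes_py.nodes_op, sharing name but not unique_name:

relu_nodes = [node for node in modulegraph.nodes_py.nodes_op
              if node.name == 'relu']
# relu is invoked three times in forward(), so three graph nodes
# should correspond to the same module object
assert len(relu_nodes) == 3
assert len({node.unique_name for node in relu_nodes}) == 3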
Example #4
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Build the graph for the model.
        """
        from nni.common.graph_utils import TorchModuleGraph

        # check that the inputs are legal
        if traced_model is None:
            # the user should provide model & dummy_input to trace
            # the model, or pass in an already traced model
            assert model is not None and dummy_input is not None
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        self.dependency = dict()
        self.build_dependency()
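A sketch of the two ways this constructor can be driven, using ChannelDependency as a concrete subclass (the import path follows NNI 2.x and may differ across versions):

import torch
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

net = MyModule()  # hypothetical user model
dummy_input = torch.rand(1, 3, 224, 224)
# Option 1: let the dependency class trace the model itself
dep = ChannelDependency(model=net, dummy_input=dummy_input)
# Option 2: reuse an existing trace from a TorchModuleGraph
graph = TorchModuleGraph(net, dummy_input)
dep = ChannelDependency(traced_model=graph.trace)
print(dep.dependency_sets)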
Example #5
    def group_weight_names_by_graph(self):
        """
        Populate self.attention_name_groups by running inference on the module graph.
        Currently, each group inferred by AttentionWeightDependency is limited to a set of four weights, with the
        first three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj.
        """
        try:
            module_graph = TorchModuleGraph(self.bound_model, self.dummy_input)
            dependency_tracer = AttentionWeightDependency(
                traced_model=module_graph.trace)
            self.attention_name_groups = dependency_tracer.dependency_sets
            self.group_weights_by_name()

        except Exception as e:
            raise RuntimeError(
                'Graph trace failed: please check dummy_input, or specify attention_name_groups.\n'
                'Exception message: ' + str(e))
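For illustration, the inferred groups might look like this (the module names are hypothetical and depend on the traced model):

# self.attention_name_groups == [
#     ['encoder.layer.0.attention.self.query',
#      'encoder.layer.0.attention.self.key',
#      'encoder.layer.0.attention.self.value',
#      'encoder.layer.0.attention.output.dense'],
#     ...
# ]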
Example #6
def _get_dependency(self, dummy_input: Any):
    # get the channel dependency and group dependency
    # channel dependency format: [[module_name1, module_name2], [module_name3], ...]
    # group dependency format: {module_name: group_num}
    self.pruner._unwrap_model()
    graph = TorchModuleGraph(model=self.pruner.bound_model,
                             dummy_input=dummy_input)
    channel_dependency = ChannelDependency(
        model=self.pruner.bound_model,
        dummy_input=dummy_input,
        traced_model=graph.trace).dependency_sets
    group_dependency = GroupDependency(
        model=self.pruner.bound_model,
        dummy_input=dummy_input,
        traced_model=graph.trace).dependency_sets
    self.pruner._wrap_model()
    return channel_dependency, group_dependency
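A hypothetical call showing the return formats described in the comments above (module names and group numbers are illustrative):

channel_dependency, group_dependency = self._get_dependency(torch.rand(1, 3, 224, 224))
# channel_dependency: [['layer1.conv1', 'layer1.conv2'], ['layer2.conv1'], ...]
# group_dependency:   {'layer3.conv1': 32, 'layer3.conv2': 32}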
Example #7
    def generate_graph(self, dummy_input: Any) -> TorchModuleGraph:
        """
        Generate a `TorchModuleGraph` instance of `self.bound_model` based on `jit.trace`.

        Parameters
        ----------
        dummy_input
            The dummy input for `jit.trace`; users should put it on the right device before passing it in.

        Returns
        -------
        TorchModuleGraph
            A `TorchModuleGraph` instance.
        """
        self._unwrap_model()
        graph = TorchModuleGraph(model=self.bound_model, dummy_input=dummy_input)
        self._wrap_model()
        return graph
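A minimal usage sketch, assuming `compressor` is an instance of a class providing this method, already bound to a model on `device`:

dummy_input = torch.rand(1, 3, 224, 224).to(device)
graph = compressor.generate_graph(dummy_input)
print([node.unique_name for node in graph.nodes_py.nodes_op])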
Example #8
    def test_module_unpack(self):
        """
        Test the tuple/list unpacking function of TorchModuleGraph.
        The following models are from issue 2756
        https://github.com/microsoft/nni/issues/2756.
        MyModule will have two successive tuple-unpack operations
        between B and C.
        """
        class CBR(nn.Module):
            def __init__(self, i, o):
                super(CBR, self).__init__()
                self.conv1 = nn.Conv2d(i, o, kernel_size=1)
                self.bn1 = nn.BatchNorm2d(o)
                self.act1 = nn.ReLU()

            def forward(self, x):
                return self.act1(self.bn1(self.conv1(x)))

        class A(nn.Module):
            def __init__(self):
                super(A, self).__init__()
                self.conv1 = CBR(3, 6)
                self.conv2 = CBR(6, 8)
                self.conv3 = CBR(6, 12)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                x3 = self.conv3(x1)
                return (x2, x3)

        class B1(nn.Module):
            def __init__(self):
                super(B1, self).__init__()
                self.conv1 = CBR(12, 32)
                self.conv2 = CBR(32, 32)
                self.conv3 = CBR(32, 32)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x1)
                x3 = self.conv3(x2)
                return (x1, x2, x3)

        class B(nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.b = B1()

            def forward(self, x):
                return self.b(x[-1])

        class C(nn.Module):
            def __init__(self):
                super(C, self).__init__()
                self.conv1 = CBR(8, 32)
                self.conv2 = CBR(12, 32)
                self.conv3 = CBR(32, 32)
                self.conv4 = CBR(32, 32)
                self.conv5 = CBR(32, 32)

            def forward(self, x):
                return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
                        self.conv4(x[3]), self.conv5(x[4]))

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.a = A()
                self.b = B()
                # self.dummy = Dummy()
                self.c = C()

            def forward(self, x):
                x_a = self.a(x)
                x_b = self.b(x_a)
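                # x_a is a 2-tuple and x_b is a 3-tuple; '+' concatenates
                # them into the 5-tuple that C.forward expects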
                xc = self.c(x_a + x_b)
                return xc

        dummy_input = torch.rand(1, 3, 28, 28)
        model = MyModule()
        graph = TorchModuleGraph(model, dummy_input)
        graph.unpack_manually()
        for node in graph.nodes_py.nodes_op:
            # The inputs of the function nodes should
            # not come from a TupleUnpack node, because
            # all the TupleUnpack nodes have been removed (unpacked)
            # manually
            for _input in node.inputs:
                if _input in graph.output_to_node:
                    predecessor = graph.output_to_node[_input]
                    assert predecessor.op_type != TUPLE_UNPACK_KIND