Example #1
0
 def __init__(self,
              model,
              config_list,
              pruning_algorithm,
              optimizer=None,
              dependency_aware=False,
              dummy_input=None,
              **algo_kwargs):
     """
     Initialize the pruner, optionally in dependency-aware mode.

     Parameters
     ----------
     model : torch.nn.Module
         The model to be pruned.
     config_list : list
         Pruning configuration, forwarded to the parent pruner.
     pruning_algorithm : str
         Name of the underlying pruning algorithm / masker.
     optimizer : torch.optim.Optimizer, optional
         Optimizer used during pruning, forwarded to the parent pruner.
     dependency_aware : bool
         If True, prune with channel/group dependencies taken into account.
     dummy_input : torch.Tensor, optional
         Example input used to trace the model. Required (non-None) when
         ``dependency_aware`` is True.
     **algo_kwargs
         Extra keyword arguments forwarded to the pruning algorithm.

     Raises
     ------
     ValueError
         If ``dependency_aware`` is True but ``dummy_input`` is None.
     """
     super().__init__(model,
                      config_list,
                      pruning_algorithm=pruning_algorithm,
                      optimizer=optimizer,
                      **algo_kwargs)
     self.dependency_aware = dependency_aware
     # set the dependency-aware switch for the masker
     self.masker.dependency_aware = dependency_aware
     self.dummy_input = dummy_input
     if self.dependency_aware:
         # `assert` is stripped when Python runs with -O, so validate
         # the input with an explicit exception instead.
         if self.dummy_input is None:
             raise ValueError(
                 "When dependency_aware is set, the dummy_input should not be None")
         # Get the TorchModuleGraph of the target model.
         # To trace the model, we need to unwrap the wrappers first.
         self._unwrap_model()
         self.graph = TorchModuleGraph(model, dummy_input)
         self._wrap_model()
         # Build the dependencies from the traced graph; keep the raw
         # dependency objects in locals so the attributes are bound once,
         # each with a single, consistent type.
         channel_depen = ChannelDependency(traced_model=self.graph.trace)
         group_depen = GroupDependency(traced_model=self.graph.trace)
         # Map every layer name to the dependency set it belongs to,
         # so per-layer lookups during pruning are O(1).
         self.channel_depen = {
             name: dep_set
             for dep_set in channel_depen.dependency_sets
             for name in dep_set
         }
         self.group_depen = group_depen.dependency_sets
Example #2
0
 def test_channel_dependency(self):
     """Check detected channel-dependency sets against the ground truth."""
     out_dir = os.path.join(prefix, 'dependency')
     os.makedirs(out_dir, exist_ok=True)
     for name in model_names:
         print('Analyze channel dependency for %s' % name)
         # build the model and a matching example input on the target device
         net = getattr(models, name)().to(device)
         dummy_input = torch.ones(1, 3, 224, 224).to(device)
         dependency = ChannelDependency(net, dummy_input)
         expected = channel_dependency_ground_truth[name]
         # count only non-trivial sets (more than one layer) and make sure
         # each of them appears in the ground truth
         multi_layer_sets = 0
         for dep_set in dependency.dependency_sets:
             if len(dep_set) <= 1:
                 continue
             multi_layer_sets += 1
             assert dep_set in expected
         assert multi_layer_sets == len(expected)
         dependency.export(os.path.join(out_dir, name))
    def parse_model(self):
        """
        Trace ``self.bound_model`` and collect the structures needed to
        predict its latency.

        Reads ``self.bound_model`` (the target torch.nn.Module) and
        ``self.dummy_input`` (an example input tensor for tracing) from
        the instance; takes no parameters and returns nothing.

        Side effects: sets ``self.channel_depen``, ``self.group_depen``,
        ``self.graph``, ``self.name2module`` (name -> module mapping),
        ``self.filter_count`` (Conv2d name -> out_channels), and resets
        ``self.measured_data`` to an empty list.
        """
        # NOTE(review): torch.onnx.set_training is deprecated/removed in
        # newer torch releases — confirm the pinned torch version supports it.
        with torch.onnx.set_training(self.bound_model, False):
            # We need to trace the model in this way, else it will have problems
            traced = torch.jit.trace(self.bound_model, self.dummy_input)
        # Dependency analyses are built from the same traced graph.
        self.channel_depen = ChannelDependency(traced_model=traced)
        self.group_depen = GroupDependency(traced_model=traced)
        # Reuse the graph already constructed by the channel-dependency analysis.
        self.graph = self.channel_depen.graph
        self.name2module = {}
        self.filter_count = {}
        for name, module in self.bound_model.named_modules():
            self.name2module[name] = module
            # Only Conv2d layers contribute filter counts for latency prediction.
            if isinstance(module, nn.Conv2d):
                self.filter_count[name] = module.out_channels
        self.measured_data = []