def speedup_model(self):
    """
    Speed up the bound model in place.

    There are basically two steps: first, do mask/shape inference;
    second, replace the compressed modules. The model is switched to
    evaluation mode for the duration of the speedup and its original
    train/eval mode is restored afterwards.
    """
    _logger.info("start to speed up the model")
    self.initialize_speedup()
    training = self.bound_model.training
    # set to the evaluation mode so shape inference is deterministic
    self.bound_model.train(False)
    # resolve mask conflicts between interdependent layers BEFORE the
    # masks are propagated through the graph (the log line below was
    # previously emitted after inference, which was misleading)
    # TODO: supposed to fix the conflict after the sparsity propagation,
    # which is more elegant
    _logger.info('resolve the mask conflict')
    fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    _logger.info("infer module masks...")
    self.infer_modules_masks()
    # load the original state dict before replacing the modules
    self.bound_model.load_state_dict(self.ori_state_dict)
    _logger.info("replace compressed modules...")
    # the mask conflict should be already resolved
    self.replace_compressed_modules()
    # restore the caller's original train/eval mode
    self.bound_model.train(training)
    _logger.info("speedup done")
def test_mask_conflict(self):
    """
    Prune each torchvision model with random per-layer sparsities, run
    ``fix_mask_conflict`` on the exported masks, and verify that every
    set of channel-dependent layers ends up with identical pruned
    channel indices (and bias indices, when bias masks are present).
    """
    outdir = os.path.join(prefix, 'masks')
    os.makedirs(outdir, exist_ok=True)
    for name in model_names:
        print('Test mask conflict for %s' % name)
        model = getattr(models, name)
        net = model().to(device)
        dummy_input = torch.ones(1, 3, 224, 224).to(device)
        # randomly generate the prune sparsity for each Conv2d layer
        cfglist = []
        for layername, layer in net.named_modules():
            if isinstance(layer, nn.Conv2d):
                # pruner cannot allow the sparsity to be 0 or 1
                sparsity = np.random.uniform(0.01, 0.99)
                cfg = {
                    'op_types': ['Conv2d'],
                    'op_names': [layername],
                    'sparsity': sparsity
                }
                cfglist.append(cfg)
        pruner = L1FilterPruner(net, cfglist)
        pruner.compress()
        ck_file = os.path.join(outdir, '%s.pth' % name)
        mask_file = os.path.join(outdir, '%s_mask' % name)
        pruner.export_model(ck_file, mask_file)
        pruner._unwrap_model()
        # Fix the mask conflict
        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
        # use the channel dependency ground truth to check if the mask
        # conflict was fixed successfully
        for dset in channel_dependency_ground_truth[name]:
            lset = list(dset)
            ref_mask = fixed_mask[lset[0]]
            ref_w_index = self.get_pruned_index(ref_mask['weight'])
            # start at index 1: comparing lset[0] with itself is pointless
            for other in lset[1:]:
                other_mask = fixed_mask[other]
                assert ref_mask['weight'].size(0) == other_mask['weight'].size(0)
                w_index = self.get_pruned_index(other_mask['weight'])
                assert ref_w_index == w_index
                # BUG FIX: mask entries are dicts, so membership must be
                # tested with `in`; the old hasattr(dict, 'bias') was
                # always False and silently skipped the bias check
                if 'bias' in ref_mask and 'bias' in other_mask:
                    b_index1 = self.get_pruned_index(ref_mask['bias'])
                    b_index2 = self.get_pruned_index(other_mask['bias'])
                    assert b_index1 == b_index2
def speedup_model(self):
    """
    Speed up the bound model in place.

    There are basically two steps: first, do mask/shape inference;
    second, replace modules. The model's original train/eval mode is
    restored after the speedup completes.
    """
    # remember the caller's train/eval mode so it can be restored below
    training = self.bound_model.training
    _logger.info("start to speed up the model")
    _logger.info("fix the mask conflict of the interdependent layers")
    # fix_mask_conflict also reports which conv dimension the masks
    # prune; propagate that choice to the shape-inference machinery
    _, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    set_conv_prune_dim(conv_prune_dim)
    _logger.info("infer module masks...")
    self.infer_modules_masks()
    _logger.info("replace compressed modules...")
    # the mask conflict is already resolved at this point
    self.replace_compressed_modules()
    # restore the original train/eval mode
    self.bound_model.train(training)
    _logger.info("speedup done")