def speedup_model(self):
    """
    There are basically two steps: first, do mask/shape inference;
    second, replace the modules.
    """
    training = self.bound_model.training
    _logger.info("start to speed up the model")
    _logger.info("fix the mask conflict of the interdependent layers")
    fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    _logger.info("infer module masks...")
    self.infer_modules_masks()
    _logger.info("replace compressed modules...")
    self.replace_compressed_modules()
    self.bound_model.train(training)
    _logger.info("speedup done")
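
# A minimal usage sketch for the method above. It assumes the NNI-style
# ModelSpeedup wrapper that owns `bound_model`, `masks`, and `dummy_input`;
# the exact import path and constructor signature vary between NNI
# versions, so treat them as assumptions.
import torch
import torchvision.models as models
from nni.compression.pytorch import ModelSpeedup

model = models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
# 'mask.pth' is a hypothetical mask file exported earlier by a pruner
m_speedup = ModelSpeedup(model, dummy_input, 'mask.pth')
m_speedup.speedup_model()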
def test_mask_conflict(self):
    outdir = os.path.join(prefix, 'masks')
    os.makedirs(outdir, exist_ok=True)
    for name in model_names:
        print('Test mask conflict for %s' % name)
        model = getattr(models, name)
        net = model().to(device)
        dummy_input = torch.ones(1, 3, 224, 224).to(device)
        # randomly generate the pruning sparsity for each layer
        cfglist = []
        for layername, layer in net.named_modules():
            if isinstance(layer, nn.Conv2d):
                # the pruner does not allow the sparsity to be 0 or 1
                sparsity = np.random.uniform(0.01, 0.99)
                cfg = {
                    'op_types': ['Conv2d'],
                    'op_names': [layername],
                    'sparsity': sparsity
                }
                cfglist.append(cfg)
        pruner = L1FilterPruner(net, cfglist)
        pruner.compress()
        ck_file = os.path.join(outdir, '%s.pth' % name)
        mask_file = os.path.join(outdir, '%s_mask' % name)
        pruner.export_model(ck_file, mask_file)
        pruner._unwrap_model()
        # fix the mask conflict
        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
        # use the channel-dependency ground truth to check whether
        # the mask conflict was fixed successfully
        for dset in channel_dependency_ground_truth[name]:
            lset = list(dset)
            for i, _ in enumerate(lset):
                assert fixed_mask[lset[0]]['weight'].size(0) == \
                    fixed_mask[lset[i]]['weight'].size(0)
                w_index1 = self.get_pruned_index(fixed_mask[lset[0]]['weight'])
                w_index2 = self.get_pruned_index(fixed_mask[lset[i]]['weight'])
                assert w_index1 == w_index2
                # fixed_mask[...] is a dict, so test key membership
                # instead of using hasattr on it
                if 'bias' in fixed_mask[lset[0]]:
                    b_index1 = self.get_pruned_index(fixed_mask[lset[0]]['bias'])
                    b_index2 = self.get_pruned_index(fixed_mask[lset[i]]['bias'])
                    assert b_index1 == b_index2
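
# `get_pruned_index` is used above but not defined in this section. A
# plausible sketch, assuming each mask is a torch tensor whose pruned
# output channels are all-zero along dimension 0 (this helper and its
# semantics are an assumption, not the test suite's actual code):
def get_pruned_index(self, mask):
    """Return the set of output-channel indices that are fully masked out."""
    flat = mask.reshape(mask.size(0), -1)
    # a channel counts as pruned when every element of its mask slice is zero
    return set(i for i in range(flat.size(0)) if flat[i].sum().item() == 0)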
def compress(model, dummy, pruner_cls, config_list, ori_metric=1.00,
             metric_thres=0.01, sensitivity=None, trace=None, verbose=True):
    if sensitivity:
        config_list = update_sparsity_by_sensitivity(
            config_list, ori_metric, metric_thres, sensitivity)
    compressed_model = copy.deepcopy(model)
    pruner = pruner_cls(compressed_model, config_list)
    compressed_model = pruner.compress()
    mask_path = "/tmp/mask.pth"
    pruner.export_model(model_path='/tmp/model.pth', mask_path=mask_path)
    pruner._unwrap_model()
    print("fixing mask conflict...")
    fixed_mask = fix_mask_conflict(mask_path, compressed_model, dummy, trace)
    # restore the original weights before applying the fixed masks
    compressed_model.load_state_dict(model.state_dict())
    apply_compression_results(compressed_model, fixed_mask)
    if verbose:
        count_zero(compressed_model, verbose=False)
        from thop import profile
        # thop expects the inputs as a tuple
        macs, params = profile(compressed_model, inputs=(dummy,), verbose=False)
        print("MACs: {} G, Params: {} M".format(macs / 1e9, params / 1e6))
    speedup_model = speedup(compressed_model, dummy, fixed_mask, trace)
    if verbose:
        count_zero(speedup_model, verbose=False)
    return speedup_model, fixed_mask
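
# A usage sketch for `compress`, assuming the helpers it references
# (`update_sparsity_by_sensitivity`, `apply_compression_results`,
# `count_zero`, `speedup`) live in the surrounding module; the import
# path for L1FilterPruner and the 0.5 sparsity are illustrative.
import torch
import torchvision.models as models
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

net = models.resnet18()
dummy = torch.randn(1, 3, 224, 224)
cfg = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
pruned_net, fixed_mask = compress(net, dummy, L1FilterPruner, cfg, verbose=True)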