Example #1
    def speedup_model(self):
        """
        Speed up the model in two steps:
        first, infer the masks/shapes of every module;
        second, replace the compressed modules with smaller ones.
        """
        training = self.bound_model.training
        _logger.info("start to speed up the model")
        # masks of interdependent layers (e.g. channel-dependent convs) must agree
        _logger.info("fix the mask conflict of the interdependent layers")
        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
        _logger.info("infer module masks...")
        self.infer_modules_masks()
        _logger.info("replace compressed modules...")
        self.replace_compressed_modules()
        # restore the original train/eval mode
        self.bound_model.train(training)
        _logger.info("speedup done")
Example #2
    def test_mask_conflict(self):
        outdir = os.path.join(prefix, 'masks')
        os.makedirs(outdir, exist_ok=True)
        for name in model_names:
            print('Test mask conflict for %s' % name)
            model = getattr(models, name)
            net = model().to(device)
            dummy_input = torch.ones(1, 3, 224, 224).to(device)
            # randomly generate the pruning sparsity for each layer
            cfglist = []
            for layername, layer in net.named_modules():
                if isinstance(layer, nn.Conv2d):
                    # the pruner does not allow a sparsity of exactly 0 or 1
                    sparsity = np.random.uniform(0.01, 0.99)
                    cfg = {
                        'op_types': ['Conv2d'],
                        'op_names': [layername],
                        'sparsity': sparsity
                    }
                    cfglist.append(cfg)
            pruner = L1FilterPruner(net, cfglist)
            pruner.compress()
            ck_file = os.path.join(outdir, '%s.pth' % name)
            mask_file = os.path.join(outdir, '%s_mask' % name)
            pruner.export_model(ck_file, mask_file)
            pruner._unwrap_model()
            # Fix the mask conflict
            fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)

            # use the channel dependency ground truth to check whether
            # the mask conflict was fixed successfully
            for dset in channel_dependency_ground_truth[name]:
                lset = list(dset)
                for i, _ in enumerate(lset):
                    assert fixed_mask[lset[0]]['weight'].size(0) == fixed_mask[
                        lset[i]]['weight'].size(0)
                    w_index1 = self.get_pruned_index(
                        fixed_mask[lset[0]]['weight'])
                    w_index2 = self.get_pruned_index(
                        fixed_mask[lset[i]]['weight'])
                    assert w_index1 == w_index2
                    if 'bias' in fixed_mask[lset[0]]:
                        b_index1 = self.get_pruned_index(
                            fixed_mask[lset[0]]['bias'])
                        b_index2 = self.get_pruned_index(
                            fixed_mask[lset[i]]['bias'])
                        assert b_index1 == b_index2
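The test above calls self.get_pruned_index, which is defined elsewhere in the test class. A plausible sketch of what it computes (this reimplementation is an assumption, not the original helper): the sorted indices of output channels whose mask entries are all zero, so two channel-dependent layers pass the assertions exactly when they prune the same filters.

import torch

def get_pruned_index(mask):
    # hypothetical helper: indices of output channels (dim 0)
    # whose mask is entirely zero, i.e. the pruned filters
    flat = mask.reshape(mask.size(0), -1)
    pruned = (flat.abs().sum(dim=1) == 0).nonzero(as_tuple=True)[0]
    return sorted(pruned.tolist())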
Example #3
def compress(model,
             dummy,
             pruner_cls,
             config_list,
             ori_metric=1.00,
             metric_thres=0.01,
             sensitivity=None,
             trace=None,
             verbose=True):
    if sensitivity:
        config_list = update_sparsity_by_sensitivity(config_list, ori_metric,
                                                     metric_thres, sensitivity)
    compressed_model = copy.deepcopy(model)
    pruner = pruner_cls(compressed_model, config_list)
    compressed_model = pruner.compress()

    mask_path = "/tmp/mask.pth"
    pruner.export_model(model_path='/tmp/model.pth', mask_path=mask_path)
    pruner._unwrap_model()

    print("fixing mask conflict...")
    fixed_mask = fix_mask_conflict(mask_path, compressed_model, dummy, trace)
    # mask = torch.load(mask_path)

    # restore the original weights, then re-apply the fixed masks
    compressed_model.load_state_dict(model.state_dict())
    apply_compression_results(compressed_model, fixed_mask)
    if verbose:
        count_zero(compressed_model, verbose=False)
        from thop import profile
        macs, params = profile(compressed_model, inputs=(dummy,), verbose=False)
        print("MACs: {} G, Params: {} M".format(macs / 1000000000,
                                                params / 100000))
    speedup_model = speedup(compressed_model, dummy, fixed_mask, trace)
    if verbose:
        count_zero(speedup_model, verbose=False)
    return speedup_model, fixed_mask
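A hypothetical call to the helper above, reusing NNI's L1FilterPruner from Example #2 (the import path is an assumption that depends on the NNI version; count_zero, speedup, and update_sparsity_by_sensitivity are project-local helpers assumed to be importable alongside compress):

import torch
from torchvision import models
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

net = models.resnet18()
dummy = torch.randn(1, 3, 224, 224)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruned_net, masks = compress(net, dummy, L1FilterPruner, config_list,
                             verbose=False)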