コード例 #1
0
def prune_model(model):
    """Prune the first conv of each block by filter L1 norm.

    Walks ``model.modules()`` and collects every module exposing a ``conv1``
    attribute that is an ``nn.Conv2d``. The k-th such block has the
    ``block_prune_probs[k]`` fraction of its ``conv1`` filters with the
    smallest L1 norm removed. Blocks beyond the end of the schedule are left
    untouched. The original code called ``m.conv1`` on every module, which
    raised AttributeError on the first module without ``conv1``.

    The model is moved to CPU and pruned in place.

    Args:
        model: network to prune.

    Returns:
        The same ``model`` object, pruned.
    """
    model.cpu()
    # NOTE(review): the tracing input is 5-D (1, 30, 3, 250, 250); presumably
    # the model consumes a stack of 30 frames -- confirm against the model.
    DG = tp.DependencyGraph().build_dependency(model,
                                               torch.randn(1, 30, 3, 250, 250))

    def prune_conv(conv, pruned_prob):
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        # Per-filter L1 norm over (in_channels, kH, kW).
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        num_pruned = int(out_channels * pruned_prob)
        if num_pruned == 0:
            return  # nothing to remove for this layer
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()  # smallest L1 norms
        plan = DG.get_pruning_plan(conv, tp.prune_conv, prune_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    # Snapshot the blocks first so pruning cannot interfere with traversal.
    blocks = [m for m in model.modules()
              if isinstance(getattr(m, "conv1", None), nn.Conv2d)]
    # zip stops at the shorter sequence: extra blocks are not pruned.
    for block, prob in zip(blocks, block_prune_probs):
        prune_conv(block.conv1, prob)
    print(model)
    return model
コード例 #2
0
    def random_prune(model, example_inputs, output_transform):
        """Randomly prune every Conv2d / BatchNorm2d layer and report the result.

        For each prunable layer, ``min(channels // 2, 10)`` randomly chosen
        channels are removed through a dependency-graph pruning plan. The model
        is then run once on ``example_inputs``. The function prints the model,
        ``model_name``, the parameter count before/after, and the output shape.

        Args:
            model: network to prune; moved to CPU, put in eval mode, and
                modified in place.
            example_inputs: tensor used to trace the model and for the final
                forward pass.
            output_transform: optional callable applied to the model output;
                also forwarded to ``build_dependency``.

        NOTE(review): ``model_name`` is not defined in this function;
        presumably it comes from the enclosing scope -- confirm.
        """
        model.cpu().eval()
        targets = [
            layer for layer in model.modules()
            if isinstance(layer, (nn.Conv2d, nn.BatchNorm2d))
        ]
        params_before = tp.utils.count_params(model)
        graph = tp.DependencyGraph().build_dependency(
            model,
            example_inputs=example_inputs,
            output_transform=output_transform,
        )

        for layer in targets:
            if isinstance(layer, nn.Conv2d):
                fn = tp.prune_conv
            else:  # only BatchNorm2d can reach this branch
                fn = tp.prune_batchnorm
            n_channels = tp.utils.count_prunable_channels(layer)
            n_pick = min(n_channels // 2, 10)
            chosen = random.sample(list(range(n_channels)), n_pick)
            graph.get_pruning_plan(layer, fn, chosen).exec()

        print(model)
        with torch.no_grad():
            result = model(example_inputs)
            if output_transform:
                result = output_transform(result)
            print(model_name)
            print("  Params: %s => %s" % (params_before, tp.utils.count_params(model)))
            print("  Output: ", result.shape)
            print("------------------------------------------------------\n")
コード例 #3
0
ファイル: prune.py プロジェクト: jelenab98/swiftnet
def prune_lottery_ticket(model,
                         initial_model,
                         pruning_indices,
                         percentages,
                         BasicBlock,
                         p=1,
                         H=1024,
                         W=2048):
    """Choose filters on a trained model, then prune them from ``initial_model``.

    The work happens in two passes.

    1. For every ``BasicBlock`` of ``model`` whose ``percentages`` entry is
       non-zero, ``tp.prune.strategy.LNStrategy(p)`` picks ``conv1``/``conv2``
       filters. These are merged into ``pruning_indices`` via
       ``merge_pruning_indices``. ``model`` itself is not pruned.
    2. A dependency graph of ``initial_model`` is built on a random
       ``(1, 3, H, W)`` input. Every block's ``conv1``/``conv2`` is then
       pruned with the accumulated index lists.

    Args:
        model: trained network, used only to score filters.
        initial_model: network that is actually pruned, in place.
        pruning_indices: mutable sequence with two slots per BasicBlock
            (``conv1`` at ``2*k``, ``conv2`` at ``2*k + 1``). Element ``[1]``
            of each slot is the index list used in pass 2. Updated in place.
        percentages: per-block pruning amount; 0 skips the block in pass 1.
        BasicBlock: block class to look for.
        p: norm order passed to ``LNStrategy``.
        H, W: height and width of the tracing input.

    Returns:
        ``(initial_model, prune_stats)``, where ``prune_stats`` holds the
        per-layer summary strings that were also printed.
    """
    initial_model.cpu()
    model.cpu()
    prune_stats = []

    def prune_conv(conv, idx, amount=0.2, p=1):
        # Score this conv's filters and fold the chosen ones into slot `idx`.
        selector = tp.prune.strategy.LNStrategy(p)
        chosen = selector.apply(conv.weight, amount)
        # Only used for the log line below.
        n_to_prune = max(int(amount * len(conv.weight)), 1)
        if isinstance(chosen, int):
            chosen = [chosen]

        stats = f"Layer {conv}, number to prune: {n_to_prune}, indices to prune: {chosen}"
        print(stats)
        prune_stats.append(stats)
        slot = pruning_indices[idx]
        pruning_indices[idx] = (slot[0],
                                merge_pruning_indices(slot[0], slot[1], chosen))

    # Pass 1: score filters on the trained model.
    trained_blocks = (m for m in model.modules() if isinstance(m, BasicBlock))
    for k, block in enumerate(trained_blocks):
        percentage = percentages[k]
        if percentage == 0:
            continue
        prune_conv(block.conv1, 2 * k, percentage, p)
        prune_conv(block.conv2, 2 * k + 1, percentage, p)

    del model

    # Pass 2: prune the initial weights with the accumulated index lists.
    DG = tp.DependencyGraph().build_dependency(initial_model,
                                               torch.randn(1, 3, H, W))
    fresh_blocks = (m for m in initial_model.modules() if isinstance(m, BasicBlock))
    for k, block in enumerate(fresh_blocks):
        for offset, attr in enumerate(("conv1", "conv2")):
            plan = DG.get_pruning_plan(getattr(block, attr), tp.prune_conv,
                                       pruning_indices[2 * k + offset][1])
            plan.exec()
    return initial_model, prune_stats
コード例 #4
0
ファイル: utils.py プロジェクト: smocilac/structured_pruning
def hardprune_f(model, calc_prun_cand, pr_percentage=0.1):
    """Build a dependency graph for ``model`` and execute candidate pruning plans.

    ``calc_prun_cand(model, graph, pr_percentage)`` must return an iterable of
    pruning plans. Each plan is executed in order, modifying ``model`` in
    place. The graph is traced with a random ``(1, 3, 32, 32)`` input.

    Returns:
        None.
    """
    graph = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 32, 32))
    for plan in calc_prun_cand(model, graph, pr_percentage):
        plan.exec()
コード例 #5
0
ファイル: ghostNet.py プロジェクト: the-cmyk/SandBox
def prune_model(model, norm_pruned_percent=0.2, gm_pruned_percent=0.2):
    """Prune every Conv2d with a combined L2-norm + geometric-median criterion.

    For each ``nn.Conv2d`` (in ``named_modules`` order), two sets of filters
    are removed:

    * the ``norm_pruned_percent`` fraction with the smallest L2 norm, and
    * among the surviving filters, the ``gm_pruned_percent`` fraction whose
      summed Euclidean distance to the other survivors is smallest (closest
      to the geometric median, i.e. most redundant).

    Everything runs on CPU. The original round-tripped through ``.cuda()``,
    which required a GPU while producing identical values.

    Args:
        model: network to prune; moved to CPU and pruned in place.
        norm_pruned_percent: fraction of filters removed by the norm criterion.
        gm_pruned_percent: fraction of filters removed by the GM criterion.

    Returns:
        The same ``model`` object, pruned.
    """
    model.cpu()
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 112, 112))

    def prune_conv(conv, Norm_pruned_percent, GM_pruned_percent):
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]

        # Norm criterion: squared L2 norm (the square root would not change the
        # ordering). For L1 use np.sum(np.abs(weight), axis=(1, 2, 3)).
        norm = np.sum(np.square(weight), axis=(1, 2, 3))
        order = np.argsort(norm)
        num_norm_pruned = int(out_channels * Norm_pruned_percent)
        prune_index = order[:num_norm_pruned].tolist()  # small-norm filters
        large_norm_index = order[num_norm_pruned:]      # survivors of the norm step

        # GM criterion on the survivors: filters with the smallest summed
        # distance to the others are the most replaceable.
        num_GM_pruned = int(out_channels * GM_pruned_percent)
        weight_vec = weight.reshape(out_channels, -1)
        weight_after_norm_prune = weight_vec[large_norm_index]
        # Euclidean distance; for cosine use 1 - distance.cdist(..., 'cosine').
        distance_matrix = distance.cdist(weight_after_norm_prune,
                                         weight_after_norm_prune, 'euclidean')
        distance_sum = np.sum(np.abs(distance_matrix), axis=0)
        sorted_distances_index = distance_sum.argsort()[:num_GM_pruned]
        prune_index_GM = large_norm_index[sorted_distances_index].tolist()

        total_prune_index = prune_index + prune_index_GM

        print("norm:", prune_index)
        print("GM:", prune_index_GM)
        print("total:", total_prune_index)

        if not total_prune_index:
            return  # nothing selected for this layer
        plan = DG.get_pruning_plan(conv, tp.prune_conv, total_prune_index)
        plan.exec()

    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            print(name)
            prune_conv(m, norm_pruned_percent, gm_pruned_percent)
    return model
コード例 #6
0
def prune_model(name='',
                model=None,
                dir_models='',
                suffix='_pruned',
                im_size=224):
    """Prune ``model`` by filter L1 norm, save it, and reload it from disk.

    Convolution handling depends on ``name`` (case-insensitive):

    * names without ``'resnet'``: every ``Conv2d`` is pruned by 20%. For
      ``'naber'`` models, the last conv before the first Linear that follows
      a conv is skipped.
    * ``BasicBlock`` / ``Bottleneck`` modules: ``conv1`` and ``conv2`` are
      pruned with ``block_prune_probs``. The index advances per block and
      saturates at the last entry.

    Args:
        name: model name; selects the rules above and names the output file.
        model: network to prune; moved to the global ``device`` and modified
            in place.
        dir_models: prefix prepended to the file name (concatenated as-is).
        suffix: appended to ``name`` in the file name.
        im_size: spatial size of the square tracing input.

    Returns:
        The pruned model as reloaded with ``torch.load`` from
        ``dir_models + name + suffix + ".pth"``.
    """
    print('\nPruning Model: ' + name + '...', end='\t')

    model.to(device)

    strategy = tp.strategy.L1Strategy()
    # The tracing input must live on the model's device, otherwise the forward
    # pass inside build_dependency fails when ``device`` is a GPU.
    DG = tp.DependencyGraph().build_dependency(
        model,
        example_inputs=torch.randn(1, 3, im_size, im_size, device=device))

    def prune_conv(conv, amount=0.2):
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    # block_prune_probs = [0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3]
    last_blk = len(block_prune_probs) - 1
    lower_name = name.lower()

    # Count the Conv2d layers seen before the first Linear that follows a conv
    # (roughly the feature-extractor / classifier boundary).
    conv_index = 0
    prev_conv = False
    for m in list(model.modules())[1:]:
        if isinstance(m, modules.Conv2d):
            prev_conv = True
            conv_index += 1
        if isinstance(m, modules.Linear) and prev_conv:
            break

    blk_id = 0
    for m in list(model.modules()):
        if isinstance(m, modules.Conv2d) and 'resnet' not in lower_name:
            # For 'naber' models keep the final conv (conv_index reaches 1).
            if 'naber' not in lower_name or conv_index > 1:
                prune_conv(m)
                conv_index -= 1
        if isinstance(m, (BasicBlock, Bottleneck)):
            prune_conv(m.conv1, block_prune_probs[blk_id])
            prune_conv(m.conv2, block_prune_probs[blk_id])
            # Advance until the last schedule entry, then keep using it.
            # (The previous `< limit - 1` never reached the last entry.)
            if blk_id < last_blk:
                blk_id += 1

    print('COMPLETE')

    # 5. Save Model
    print('Saving', name + suffix + '...', end='\t')
    filename = dir_models + name + suffix + ".pth"
    torch.save(model, filename)
    print('COMPLETE')
    # NOTE: torch.load unpickles the whole model; only load trusted files.
    return torch.load(filename)
コード例 #7
0
def prune_model(model, prune_prob=0.1):
    """Prune the ``prune_prob`` fraction of filters of every Conv2d by L1 norm.

    The model is moved to CPU and traced with a random ``(1, 3, 32, 32)``
    input. Each ``torch.nn.Conv2d`` is then pruned in ``model.modules()``
    order.

    Returns:
        The same ``model`` object, pruned in place.
    """
    model.cpu()
    graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        selector = tp.strategy.L1Strategy()
        idxs = selector(conv.weight, amount=amount)
        graph.get_pruning_plan(conv, tp.prune_conv, idxs).exec()

    convs = (layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d))
    for conv in convs:
        prune_conv(conv, prune_prob)
    return model
コード例 #8
0
    def __call__(self, model, rate=0.1, example_inputs=None):
        """Randomly prune conv/linear layers until ``rate`` of the prunable params are removed.

        Candidates are every ``_ConvNd`` / ``nn.Linear`` among
        ``_PRUNABLE_MODULES``, except the last one found, which is kept intact.
        Each step draws a uniform index over the candidates' concatenated
        prunable-parameter counts, so larger layers are picked proportionally
        more often. ``self.select`` then chooses the channels to drop, and the
        plan from a dependency graph built once up front is executed.

        Args:
            model: network to prune in place.
            rate: fraction of the total prunable parameter count to remove.
                The total spans all ``_PRUNABLE_MODULES``, including the
                excluded last layer.
            example_inputs: tracing input; defaults to a random
                ``(1, 3, 256, 256)`` tensor.

        Returns:
            The same ``model`` object. It is returned unchanged when there is
            no candidate layer with parameters. Pruning stops early once every
            drawable layer has no output channels left.
        """
        import bisect
        import itertools

        if example_inputs is None:
            example_inputs = torch.randn(1, 3, 256, 256)

        DG = tp.DependencyGraph()
        DG.build_dependency(model, example_inputs=example_inputs)

        prunable_layers = []
        layer_sizes = []
        total_params = 0
        for m in model.modules():
            if isinstance(m, _PRUNABLE_MODULES):
                nparam = tp.utils.count_prunable_params(m)
                total_params += nparam
                if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)):
                    prunable_layers.append(m)
                    layer_sizes.append(nparam)
        # Leave the last conv/linear layer untouched (its params still count
        # toward total_params).
        prunable_layers = prunable_layers[:-1]
        layer_sizes = layer_sizes[:-1]

        # layer_ends[k] is the exclusive upper bound of layer k's slice in the
        # concatenated parameter-index space.
        layer_ends = list(itertools.accumulate(layer_sizes))
        num_conv_params = layer_ends[-1] if layer_ends else 0
        if num_conv_params == 0:
            # No candidate can be drawn (random.randint(0, -1) would raise).
            return model

        # Only layers with a non-empty slice can ever be drawn.
        drawable = [layer for layer, size in zip(prunable_layers, layer_sizes)
                    if size > 0]

        num_pruned = 0
        while num_pruned < total_params * rate:
            draw = random.randint(0, num_conv_params - 1)
            # First layer whose slice ends after the drawn index.
            layer_to_prune = prunable_layers[bisect.bisect_right(layer_ends, draw)]
            if layer_to_prune.weight.shape[0] < 1:
                if all(layer.weight.shape[0] < 1 for layer in drawable):
                    break  # everything is exhausted; avoid spinning forever
                continue
            idx = self.select(layer_to_prune)
            fn = tp.prune_conv if isinstance(
                layer_to_prune, nn.modules.conv._ConvNd) else tp.prune_linear
            plan = DG.get_pruning_plan(layer_to_prune, fn, idxs=idx)
            num_pruned += plan.exec()
        return model
コード例 #9
0
ファイル: prune.py プロジェクト: wxy1234567/Resnet50-pruning
def prune_model(model, num_list):
    """Prune BatchNorm channels of every torchvision ResNet Bottleneck by |gamma|.

    For the k-th ``Bottleneck``, ``num_list[2k]`` channels of ``bn1`` and
    ``num_list[2k + 1]`` channels of ``bn2`` are removed. These are the
    channels with the smallest absolute BN scale. Coupled layers are updated
    through the dependency graph.

    Args:
        model: network to prune; moved to the global ``device`` and pruned in
            place.
        num_list: number of channels to remove, two entries per Bottleneck.

    Returns:
        The same ``model`` object, pruned.
    """
    model.to(device)
    # The tracing input must be on the same device as the model.
    DG = tp.DependencyGraph().build_dependency(
        model, torch.randn(1, 3, 224, 224, device=device))

    def prune_bn(bn, num):
        # The L1 norm of a scalar BN weight is its absolute value. Sorting the
        # signed values would rank large negative gammas as "unimportant".
        L1_norm = np.abs(bn.weight.detach().cpu().numpy())
        prune_index = np.argsort(L1_norm)[:num].tolist()  # smallest |gamma|
        if not prune_index:
            return
        plan = DG.get_pruning_plan(bn, tp.prune_batchnorm, prune_index)
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.Bottleneck):
            prune_bn(m.bn1, num_list[blk_id])
            prune_bn(m.bn2, num_list[blk_id + 1])
            blk_id += 2
    return model
コード例 #10
0
def prune_model(model):
    """Prune conv1/conv2 of each ``BasicBlock`` by L1 norm, with a per-block schedule.

    The k-th BasicBlock (in ``model.modules()`` order) uses the k-th entry of
    ``[0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]`` as its pruning amount. The
    model is moved to CPU and traced with a random ``(1, 3, 32, 32)`` input.

    Returns:
        The same ``model`` object, pruned in place.
    """
    model.cpu()
    graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        selector = tp.strategy.L1Strategy()
        idxs = selector(conv.weight, amount=amount)
        graph.get_pruning_plan(conv, tp.prune_conv, idxs).exec()

    schedule = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    basic_blocks = (m for m in model.modules() if isinstance(m, BasicBlock))
    for k, block in enumerate(basic_blocks):
        ratio = schedule[k]
        prune_conv(block.conv1, ratio)
        prune_conv(block.conv2, ratio)
    return model
コード例 #11
0
ファイル: demo.py プロジェクト: wxy1234567/Resnet50-pruning
def prune_model(model):
    """Prune bn1/bn2 channels of every torchvision ResNet Bottleneck by |gamma|.

    For each Bottleneck, the ``int(channels * 0.8)`` channels of ``bn1`` and
    ``bn2`` with the smallest absolute BN weight are removed. The model is
    moved to CPU and traced with a random ``(1, 3, 224, 224)`` input. The
    schedule has 16 entries, one per Bottleneck; presumably the model is
    ResNet-50.

    Returns:
        The same ``model`` object, pruned in place.
    """
    model.cpu()
    graph = pruning.DependencyGraph().build_dependency(
        model, torch.randn(1, 3, 224, 224))

    def prune_bn(bn, pruned_prob):
        gamma = bn.weight.detach().cpu().numpy()
        n_remove = int(gamma.shape[0] * pruned_prob)
        victims = np.argsort(np.abs(gamma))[:n_remove].tolist()  # smallest |gamma|
        graph.get_pruning_plan(bn, pruning.prune_batchnorm, victims).exec()

    ratios = [0.8] * 16
    k = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.Bottleneck):
            prune_bn(m.bn1, ratios[k])
            prune_bn(m.bn2, ratios[k])
            k += 1
    return model
コード例 #12
0
def prune_model(model):
    """Prune conv1/conv2 of each ``resnet.BasicBlock`` by filter L1 norm.

    The k-th BasicBlock uses the k-th entry of
    ``[0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]`` as the fraction of filters to
    remove. The model is moved to CPU and traced with a random
    ``(1, 3, 32, 32)`` input.

    Returns:
        The same ``model`` object, pruned in place.
    """
    model.cpu()
    DG = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 32, 32))

    def prune_conv(conv, pruned_prob):
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        # L1 norm needs absolute values. A signed sum lets positive and
        # negative weights cancel and mis-ranks filters.
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        num_pruned = int(out_channels * pruned_prob)
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()  # smallest L1 norms

        plan = DG.get_pruning_plan(conv, pruning.prune_conv, prune_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blk_id = 0
    for m in model.modules():
        if isinstance(m, resnet.BasicBlock):
            prune_conv(m.conv1, block_prune_probs[blk_id])
            prune_conv(m.conv2, block_prune_probs[blk_id])
            blk_id += 1
    return model
コード例 #13
0
ファイル: pruning.py プロジェクト: Raul9595/PruneAway
def prune_model_structured(model, device, amt):
    """Prune conv1/conv2/conv3 of every Bottleneck by L1 norm, best effort.

    ``model`` must expose the real network as ``model.module``; presumably it
    is an ``nn.DataParallel``-style wrapper. The wrapped network is traced on
    a random ``(1, 3, 32, 32)`` input. For each ``Bottleneck``, the ``amt``
    amount of lowest-L1 filters is then removed from conv1, conv2 and conv3 in
    turn.

    A block whose pruning raises is skipped with a warning. Convs of that
    block pruned before the failure stay pruned.

    Args:
        model: wrapper around the network; moved to CPU, pruned in place, then
            moved to ``device``.
        device: device the pruned model is returned on.
        amt: amount passed to ``tp.strategy.L1Strategy`` for every conv.

    Returns:
        ``model`` moved to ``device``.
    """
    import warnings

    model.cpu()
    dummy_input = torch.randn(1, 3, 32, 32)

    DG = tp.DependencyGraph().build_dependency(model.module, dummy_input)

    def prune_conv(conv, amount=0.015):
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    for m in model.module.modules():
        if not isinstance(m, Bottleneck):
            continue
        try:
            prune_conv(m.conv1, amt)
            prune_conv(m.conv2, amt)
            prune_conv(m.conv3, amt)
        except Exception as err:
            # Best effort: skip blocks that cannot be pruned, but say why.
            # Catching Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            warnings.warn(f"Skipping pruning of a Bottleneck block: {err!r}")
    return model.to(device)
コード例 #14
0
        print("Average Inference time in ms:", average_inference_time)
        inference_time_matrix.append(average_inference_time)

        # This only make the prune permanant, the size of the model remains the same
        # The pruned weight are set to be zero. This won't affect the inference time
        # This is required to make the pruning package work.
        for module in conv_modules:
            prune.remove(module, 'weight')
        print("Number of params after removing mask:",
              get_num_params(conv_modules))
        for module in conv_modules:
            filter_sum_weight = module.weight.sum(
                dim=(1, 2, 3)).detach().cpu().numpy()
            pruning_idxs = np.where(filter_sum_weight == 0)[0].tolist()
            # print(pruning_idxs)
            DG = pruning.DependencyGraph()
            DG.build_dependency(model_pruned,
                                example_inputs=torch.randn(1, 3, 224, 224))
            pruning_plan = DG.get_pruning_plan(module,
                                               pruning.prune_conv,
                                               idxs=pruning_idxs)
            pruning_plan.exec()

        print("==================== Actual pruned Model ====================")
        num_params = get_num_params(conv_modules)
        print("Number of params:", num_params)
        num_params_list[3].append(num_params)
        print("Model:")
        print(model_pruned)

        model_pruned.to(DEVICE)
def prune_model(model):
    '''Prune the model.

    Prunes ``model`` in place with the Torch-Pruning toolkit and prints the
    number of pruned blocks. For every ``mobilenetv2.Block`` (in
    ``model.modules()`` order), the filters of ``conv1`` and ``conv2`` with
    the smallest L1 norm are removed. The k-th block uses the k-th entry of a
    17-entry schedule of ``L1Strategy`` amounts, chosen by the global ``args``
    (with ``n = args.prune_rate``):

    * ``args.prune_method == 'constant'``: every block uses ``n``.
    * ``'linear'``: stage-wise multiples ``0, n, n, 2n, ..., 6n``.
    * anything else (exponential): ``1 - (1 - n) ** s`` for stage exponent
      ``s``.

    Args:
        model: a trained network containing ``mobilenetv2.Block`` modules.

    Returns:
        The same ``model`` object, pruned.

    Raises:
        IndexError: if the model has more blocks than the schedule has
            entries. This is raised before anything is pruned.
    '''
    # Trace on the model's own device so GPU-resident models work too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else 'cpu'
    DG = tp.DependencyGraph().build_dependency(
        model, torch.randn(1, 3, 32, 32, device=device))

    def plot_filters_single_channel(t, index):
        # Debug helper: draw every 2-D kernel of the filters listed in `index`.
        t = t.detach().cpu()  # .numpy() below needs a CPU tensor
        # kernel depth * number of kernels
        nplots = t.shape[0] * t.shape[1]
        ncolumns = 12
        nrows = 1 + nplots // ncolumns

        count = 0
        fig = plt.figure(figsize=(ncolumns, nrows))

        # loop through all the kernels of each selected filter
        for i in range(t.shape[0]):
            if i not in index:
                continue
            for j in range(t.shape[1]):
                count += 1
                ax1 = fig.add_subplot(nrows, ncolumns, count)
                npimage = np.array(t[i, j].numpy(), np.float32)
                std = np.std(npimage)
                # avoid 0/0 (NaN image) for constant kernels
                npimage = (npimage - np.mean(npimage)) / (std if std > 0 else 1.0)
                npimage = np.minimum(1, np.maximum(0, (npimage + 0.5)))
                ax1.imshow(npimage)
                ax1.set_title(str(i) + ',' + str(j))
                ax1.axis('off')
                ax1.set_xticklabels([])
                ax1.set_yticklabels([])

        plt.tight_layout()
        plt.show()

    def plot_weights(layer, index, single_channel=True, collated=False):
        # Debug helper: visualise a conv layer's kernels (`collated` is unused).
        if isinstance(layer, nn.Conv2d):
            weight_tensor = layer.weight.data
            if single_channel:
                plot_filters_single_channel(weight_tensor, index)
        else:
            print("Can only visualize layers which are convolutional")

    def prune_conv(conv, amount=0.2):
        # Remove filters of `conv` selected by L1 norm for the given amount.
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    n = args.prune_rate  # base rate, <= 0.97
    if args.prune_method == 'constant':
        p_tuple = [n] * 17
    elif args.prune_method == 'linear':
        p_tuple = [0, n, n, n*2, n*2, n*2, n*3, n*3, n*3, n*3, n*4, n*4, n*4, n*5, n*5, n*5, n*6]
    else:
        n = 1 - n
        p_tuple = [1, n, n, n**2, n**2, n**2, n**3, n**3, n**3, n**3, n**4, n**4, n**4, n**5, n**5, n**5, n**6]
        p_tuple = [1 - x for x in p_tuple]  # convert kept fraction to pruning amount

    blocks = [m for m in model.modules() if isinstance(m, mobilenetv2.Block)]
    # Fail before touching the model rather than half-way through pruning.
    if len(blocks) > len(p_tuple):
        raise IndexError(
            f"model has {len(blocks)} blocks but the schedule has {len(p_tuple)} entries")

    counter = 0
    for m in blocks:
        prune_conv(m.conv1, p_tuple[counter])
        prune_conv(m.conv2, p_tuple[counter])
        counter += 1
    print("Num of blocks: %d" % counter)
    return model
コード例 #16
0
        return nparams_to_prune


# function wrapper
def my_pruning_fn(layer: CustomizeLayer,
                  idxs: list,
                  inplace: bool = True,
                  dry_run: bool = False):
    """Adapter exposing ``MyPruningFn.apply`` as a plain pruning function.

    All four arguments are forwarded positionally, unchanged, and the result
    of ``MyPruningFn.apply`` is returned as-is.
    """
    forwarded = (layer, idxs, inplace, dry_run)
    return MyPruningFn.apply(*forwarded)


model = FullyConnectedNet(128, 10, 256)
# Rank channels by L1 norm (tp.strategy.RandomStrategy() is an alternative).
strategy = tp.strategy.L1Strategy()

DG = tp.DependencyGraph()
# Teach the graph how to handle CustomizeLayer. Both channel getters report
# `in_dim`; a getter may return None when the layer does not change the
# tensor shape.
# NOTE(review): out-channels == in_dim presumes the layer keeps its width.
DG.register_customized_layer(
    CustomizeLayer,
    in_ch_pruning_fn=my_pruning_fn,   # prunes channels of the input tensor
    out_ch_pruning_fn=my_pruning_fn,  # prunes channels of the output tensor
    get_in_ch_fn=lambda layer: layer.in_dim,
    get_out_ch_fn=lambda layer: layer.in_dim,
)

# Trace the model once to record the dependencies between layers.
DG.build_dependency(model, example_inputs=torch.randn(1, 128))
# get a pruning plan according to the dependency graph. idxs is the indices of pruned filters.
pruning_plan = DG.get_pruning_plan(model.fc1,
コード例 #17
0
ファイル: HardPruning.py プロジェクト: hkerma/deeplearning
 def init_dependency_graph(self, model, input_w=32, input_h=32):
     """Build a ``pruning.DependencyGraph`` by tracing a random image batch.

     Args:
         model: network to trace. Falls back to ``self.model`` when None.
             Previously this argument was ignored and ``self.model`` was
             always used.
         input_w: width of the fake input image.
         input_h: height of the fake input image.

     Returns:
         The constructed dependency graph.
     """
     target = self.model if model is None else model
     # Put the fake input on the model's device instead of forcing CUDA, so
     # CPU-resident models can be traced too.
     first_param = next(target.parameters(), None)
     device = first_param.device if first_param is not None else "cpu"
     # PyTorch image tensors are NCHW: height comes before width.
     fake_input = torch.randn(1, 3, input_h, input_w, device=device)
     DG = pruning.DependencyGraph(target, fake_input=fake_input)
     return DG
コード例 #18
0
def prune_model(model):
    model.cpu()
    DG = tp.DependencyGraph().build_dependency( model, torch.randn(1, 3, 32, 32) )
    def prune_conv(conv, pruned_prob):
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
コード例 #19
0
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import torch
from torchvision.models import resnet18
import torch_pruning as pruning

model = resnet18(pretrained=True)

# Trace resnet18 once with a dummy 224x224 RGB image to record which layers
# depend on each other.
DG = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 224, 224))

# Ask the graph for a plan removing filters 2, 6 and 9 of the stem conv; the
# plan is derived from the recorded dependencies.
pruning_plan = DG.get_pruning_plan(model.conv1, pruning.prune_conv, idxs=[2, 6, 9])
print(pruning_plan)

# Apply the plan in place, then show the resulting network.
pruning_plan.exec()
print(model)
コード例 #20
0
def channel_prune(ori_model,
                  example_inputs,
                  output_transform,
                  pruned_prob=0.3,
                  thres=None,
                  ignore_idx=(230, 260, 290)):
    """Threshold-based (network-slimming style) channel pruning on a copy of a model.

    A deep copy of ``ori_model`` is put on CPU in eval mode. For every
    ``BatchNorm2d`` whose position in ``model.modules()`` is not in
    ``ignore_idx``, the channels whose absolute weight is below ``thres`` are
    removed through a dependency-graph plan. At least one channel per layer is
    always kept.

    When ``thres`` is None, it is taken from ``bn_val`` (returned by
    ``bn_analyze``) at position ``int(pruned_prob * len(bn_val))``, clamped to
    the valid range. NOTE(review): this presumes ``bn_val`` is sorted
    ascending, as the "Min"/"Max" log line suggests -- confirm in
    ``bn_analyze``.

    Args:
        ori_model: model to prune; left untouched.
        example_inputs: tracing input, also used for a final forward pass.
        output_transform: optional callable applied to the model output.
        pruned_prob: quantile used to derive ``thres`` when it is None.
        thres: explicit pruning threshold on absolute weights.
        ignore_idx: ``model.modules()`` positions to leave unpruned. The
            default values are model-specific and carried over unchanged.

    Returns:
        The pruned copy.
    """
    model = copy.deepcopy(ori_model)
    model.cpu().eval()

    prunable_module_type = (nn.BatchNorm2d,)  # a real tuple
    ignored = set(ignore_idx)
    prunable_modules = [
        m for i, m in enumerate(model.modules())
        if i not in ignored and isinstance(m, prunable_module_type)
    ]
    ori_size = tp.utils.count_params(model)
    DG = tp.DependencyGraph().build_dependency(
        model,
        example_inputs=example_inputs,
        output_transform=output_transform)
    bn_val, max_val = bn_analyze(prunable_modules,
                                 "render_img/before_pruning.jpg")
    if thres is None:
        thres_pos = int(pruned_prob * len(bn_val))
        thres_pos = max(min(thres_pos, len(bn_val) - 1), 0)
        thres = bn_val[thres_pos]
    print("Min val is %f, Max val is %f, Thres is %f" %
          (bn_val[0], bn_val[-1], thres))

    for layer_to_prune in prunable_modules:
        weight = layer_to_prune.weight.detach().cpu().numpy()
        if isinstance(layer_to_prune, nn.Conv2d):
            prune_fn = tp.prune_group_conv if layer_to_prune.groups > 1 else tp.prune_conv
            L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        elif isinstance(layer_to_prune, nn.BatchNorm2d):
            prune_fn = tp.prune_batchnorm
            L1_norm = np.abs(weight)
        else:
            raise TypeError(f"unsupported layer type: {type(layer_to_prune).__name__}")

        prun_index = np.flatnonzero(L1_norm < thres).tolist()
        if prun_index and len(prun_index) == len(L1_norm):
            # Never empty a layer: keep its strongest channel.
            del prun_index[int(np.argmax(L1_norm))]
        if not prun_index:
            continue  # nothing below the threshold in this layer

        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)
        plan.exec()

    bn_analyze(prunable_modules, "render_img/after_pruning.jpg")

    with torch.no_grad():

        out = model(example_inputs)
        if output_transform:
            out = output_transform(out)
        print("  Params: %s => %s" % (ori_size, tp.utils.count_params(model)))
        if isinstance(out, (list, tuple)):
            for o in out:
                print("  Output: ", o.shape)
        else:
            print("  Output: ", out.shape)
        print("------------------------------------------------------\n")
    return model