def prune_model(model):
    """Prune every ``nn.Conv2d`` of *model* by the L1 norm of its filters.

    The k-th convolution found by ``model.modules()`` loses the fraction
    ``block_prune_probs[k]`` of its output filters (those with the smallest
    L1 norm). Convolutions beyond the end of the schedule reuse its last
    ratio. The model is moved to the CPU and modified in place.

    NOTE(review): the dependency graph is traced with an input of shape
    ``(1, 30, 3, 250, 250)``; confirm this matches what the model expects.

    Args:
        model: network to prune.

    Returns:
        The same, pruned, ``model`` instance.
    """
    model.cpu()
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 30, 3, 250, 250))

    def prune_conv(conv, pruned_prob):
        """Remove the ``pruned_prob`` fraction of *conv*'s filters with the smallest L1 norm."""
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        num_pruned = int(out_channels * pruned_prob)
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()  # remove filters with small L1-Norm
        plan = DG.get_pruning_plan(conv, tp.prune_conv, prune_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    last_id = len(block_prune_probs) - 1
    blk_id = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # ``m`` is itself the convolution to prune (a Conv2d has no
            # ``conv1`` child). Clamp the schedule index so models with more
            # convolutions than schedule entries do not raise IndexError.
            prune_conv(m, block_prune_probs[min(blk_id, last_id)])
            blk_id += 1
    print(model)
    return model
def random_prune(model, example_inputs, output_transform, model_name=None):
    """Randomly remove channels from every Conv2d / BatchNorm2d layer and report the result.

    For each prunable layer, ``min(channels // 2, 10)`` randomly chosen
    channels are removed through the dependency graph, so coupled layers stay
    consistent. The pruned model is then run once on *example_inputs*, and the
    parameter count before/after and the output shape(s) are printed.

    Args:
        model: network to prune in place (moved to CPU and set to eval mode).
        example_inputs: input used both for tracing and for the sanity forward pass.
        output_transform: optional callable applied to the model output; may be ``None``.
        model_name: label printed in the report. Defaults to the model's class name.

    Returns:
        The pruned ``model``.
    """
    model.cpu().eval()
    prunable_module_type = (nn.Conv2d, nn.BatchNorm2d)
    prunable_modules = [m for m in model.modules() if isinstance(m, prunable_module_type)]
    ori_size = tp.utils.count_params(model)
    DG = tp.DependencyGraph().build_dependency(
        model, example_inputs=example_inputs, output_transform=output_transform
    )
    for layer_to_prune in prunable_modules:
        if isinstance(layer_to_prune, nn.Conv2d):
            prune_fn = tp.prune_conv
        else:
            prune_fn = tp.prune_batchnorm
        ch = tp.utils.count_prunable_channels(layer_to_prune)
        # At most half of the channels (and never more than 10) are removed,
        # so no layer is ever emptied.
        rand_idx = random.sample(list(range(ch)), min(ch // 2, 10))
        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, rand_idx)
        plan.exec()
    print(model)

    with torch.no_grad():
        out = model(example_inputs)
        if output_transform:
            out = output_transform(out)
    print(model_name if model_name is not None else type(model).__name__)
    print(" Params: %s => %s" % (ori_size, tp.utils.count_params(model)))
    outputs = out if isinstance(out, (list, tuple)) else [out]
    for o in outputs:
        print(" Output: ", o.shape)
    print("------------------------------------------------------\n")
    return model
def prune_lottery_ticket(model, initial_model, pruning_indices, percentages, BasicBlock, p=1, H=1024, W=2048):
    """Choose filters to drop from a trained *model* and prune *initial_model* with them.

    First pass (over *model*, which is not modified): the k-th ``BasicBlock``
    owns entries ``2k`` (``conv1``) and ``2k + 1`` (``conv2``) of
    *pruning_indices*. Unless ``percentages[k]`` is 0, the filters selected by
    an Lp-norm strategy with ratio ``percentages[k]`` are merged into the
    second element of those entries via ``merge_pruning_indices``.

    Second pass: *initial_model* is traced with a ``(1, 3, H, W)`` input and
    every block's ``conv1`` / ``conv2`` is pruned with the accumulated indices.

    Args:
        model: trained network used only to score filters.
        initial_model: network that is actually pruned (moved to CPU).
        pruning_indices: mutable sequence of ``(key, indices)`` pairs, two per
            block; updated in place.
        percentages: per-block pruning ratio (0 skips the block).
        BasicBlock: block class whose ``conv1``/``conv2`` are pruned.
        p: norm order passed to the strategy.
        H, W: height and width of the tracing input.

    Returns:
        ``(initial_model, prune_stats)``: ``prune_stats`` holds one text line
        per scored layer.
    """
    initial_model.cpu()
    model.cpu()
    prune_stats = []

    def prune_conv(conv, idx, amount=0.2, p=1):
        chosen = tp.prune.strategy.LNStrategy(p).apply(conv.weight, amount)
        n_to_prune = max(int(amount * len(conv.weight)), 1)
        if isinstance(chosen, int):
            chosen = [chosen]
        stats = f"Layer {conv}, number to prune: {n_to_prune}, indices to prune: {chosen}"
        print(stats)
        prune_stats.append(stats)
        entry = pruning_indices[idx]
        pruning_indices[idx] = (entry[0], merge_pruning_indices(entry[0], entry[1], chosen))

    # Pass 1: record which filters to drop; block k maps to conv slots 2k, 2k+1.
    trained_blocks = (m for m in model.modules() if isinstance(m, BasicBlock))
    for block_no, block in enumerate(trained_blocks):
        percentage = percentages[block_no]
        if percentage == 0:
            continue
        prune_conv(block.conv1, 2 * block_no, percentage, p)
        prune_conv(block.conv2, 2 * block_no + 1, percentage, p)
    del model

    # Pass 2: physically prune the initial (untrained) network.
    DG = tp.DependencyGraph().build_dependency(initial_model, torch.randn(1, 3, H, W))
    slot = 0
    for block in initial_model.modules():
        if not isinstance(block, BasicBlock):
            continue
        for offset, conv in enumerate((block.conv1, block.conv2)):
            DG.get_pruning_plan(conv, tp.prune_conv, pruning_indices[slot + offset][1]).exec()
        slot += 2
    return initial_model, prune_stats
def hardprune_f(model, calc_prun_cand, pr_percentage=0.1):
    """Physically remove the filters selected by *calc_prun_cand* from *model*.

    Args:
        model: network to prune in place; traced with a (1, 3, 32, 32) input.
        calc_prun_cand: callable ``(model, dependency_graph, pr_percentage)``
            returning an iterable of pruning plans.
        pr_percentage: pruning ratio forwarded unchanged to *calc_prun_cand*.
    """
    graph = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 32, 32))
    # The candidate plans are derived from the dependency graph; run each in turn.
    for plan in calc_prun_cand(model, graph, pr_percentage):
        plan.exec()
def prune_model(model, norm_pruned_percent=0.2, gm_pruned_percent=0.2):
    """Prune every ``nn.Conv2d`` with a combined L2-norm + geometric-median rule.

    For each convolution (in ``model.named_modules()`` order):

    1. the ``norm_pruned_percent`` fraction of filters with the smallest L2
       norm is removed;
    2. among the surviving filters, the ``gm_pruned_percent`` fraction (of the
       original filter count) whose summed Euclidean distance to the other
       survivors is smallest is removed as well.

    All computation runs on the CPU, where the model is moved first.

    Args:
        model: network to prune in place; traced with a (1, 3, 112, 112) input.
        norm_pruned_percent: ratio removed by the norm criterion.
        gm_pruned_percent: ratio removed by the distance criterion.

    Returns:
        The same, pruned, ``model``.
    """
    model.cpu()
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 112, 112))

    def prune_conv(conv, Norm_pruned_percent, GM_pruned_percent):
        weight_np = conv.weight.detach().cpu().numpy()
        out_channels = weight_np.shape[0]

        # Norm criterion: squared L2 (the square root does not change the
        # ranking). For L1 use np.sum(np.abs(weight_np), axis=(1, 2, 3)).
        norm = np.sum(np.square(weight_np), axis=(1, 2, 3))
        order = np.argsort(norm)
        num_norm_pruned = int(out_channels * Norm_pruned_percent)
        prune_index = order[:num_norm_pruned].tolist()  # remove filters with small L2-Norm
        large_norm_index = order[num_norm_pruned:]  # filters kept by the norm criterion

        # Distance criterion on the surviving filters. Stays on the CPU, so no
        # GPU is required.
        num_GM_pruned = int(out_channels * GM_pruned_percent)
        weight_vec = weight_np.reshape(out_channels, -1)
        weight_after_norm_prune = weight_vec[large_norm_index]
        distance_matrix = distance.cdist(weight_after_norm_prune, weight_after_norm_prune, 'euclidean')
        # for cosine similarity use: 1 - distance.cdist(..., 'cosine')
        distance_sum = np.sum(np.abs(distance_matrix), axis=0)
        # Smallest summed distance == most similar to the rest == most redundant.
        sorted_distances_index = distance_sum.argsort()[:num_GM_pruned]
        prune_index_GM = large_norm_index[sorted_distances_index].tolist()  # plain Python ints

        total_prune_index = prune_index + prune_index_GM
        print("norm:", prune_index)
        print("GM:", prune_index_GM)
        print("total:", total_prune_index)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, total_prune_index)
        plan.exec()

    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            print(name)
            prune_conv(m, norm_pruned_percent, gm_pruned_percent)
    return model
def prune_model(name='', model=None, dir_models='', suffix='_pruned', im_size=224):
    """L1-prune *model*, save it to disk and return the reloaded copy.

    The pruning scheme depends on *name* (case-insensitive):

    * names without ``'resnet'``: every ``Conv2d`` loses 20 % of its filters.
      For names containing ``'naber'``, only the first ``k - 1`` convolutions
      are pruned, where ``k`` is the number of Conv2d modules preceding the
      first ``Linear`` (in ``model.modules()`` order).
    * for any name, every ``BasicBlock`` / ``Bottleneck`` also has ``conv1``
      and ``conv2`` pruned with the per-block ratios in ``block_prune_probs``.
      Once the schedule is exhausted, its last ratio is reused.

    The pruned model is saved with ``torch.save`` to
    ``dir_models + name + suffix + ".pth"`` and loaded back.

    Args:
        name: model name; selects the scheme above and forms the file name.
        model: network to prune (moved to the module-level ``device``).
        dir_models: prefix for the output path (expected to end with a separator).
        suffix: appended to *name* in the file name.
        im_size: spatial size of the square tracing input.

    Returns:
        The model loaded back from the saved file.
    """
    print('\nPruning Model: ' + name + '...', end='\t')
    model.to(device)
    strategy = tp.strategy.L1Strategy()
    # The tracing input must live on the same device as the model.
    DG = tp.DependencyGraph().build_dependency(
        model, example_inputs=torch.randn(1, 3, im_size, im_size, device=device))

    def prune_conv(conv, amount=0.2):
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    lower_name = name.lower()
    is_resnet = 'resnet' in lower_name
    is_naber = 'naber' in lower_name

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    # block_prune_probs = [0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3]
    limit = len(block_prune_probs) - 1  # index of the last ratio
    blk_id = 0

    # Count the convolutions that precede the first Linear following a conv.
    conv_index = 0
    prev_conv = False
    for m in list(model.modules())[1:]:
        if isinstance(m, modules.Conv2d):
            prev_conv = True
            conv_index += 1
        if isinstance(m, modules.Linear) and prev_conv:
            break

    for m in list(model.modules()):
        if isinstance(m, modules.Conv2d) and not is_resnet:
            # 'naber' models keep the last of the counted convolutions intact.
            if not is_naber or conv_index > 1:
                prune_conv(m)
            conv_index -= 1
        if isinstance(m, (BasicBlock, Bottleneck)):
            prune_conv(m.conv1, block_prune_probs[blk_id])
            prune_conv(m.conv2, block_prune_probs[blk_id])
            # Advance through the whole schedule, then stay on its last entry.
            if blk_id < limit:
                blk_id += 1
    print('COMPLETE')

    # 5. Save Model
    print('Saving', name + suffix + '...', end='\t')
    filename = dir_models + name + suffix + ".pth"
    torch.save(model, filename)
    print('COMPLETE')
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which cannot unpickle a whole nn.Module; pass
    # weights_only=False here if running on such a version.
    return torch.load(filename)
def prune_model(model, prune_prob = 0.1):
    """Drop the fraction *prune_prob* of lowest-L1 filters from every Conv2d.

    The model is moved to the CPU, traced with a (1, 3, 32, 32) input and
    pruned in place; the same instance is returned.
    """
    model.cpu()
    graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        selected = tp.strategy.L1Strategy()(conv.weight, amount=amount)
        graph.get_pruning_plan(conv, tp.prune_conv, selected).exec()

    conv_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d)]
    for conv in conv_layers:
        prune_conv(conv, prune_prob)
    return model
def __call__(self, model, rate=0.1, example_inputs=None):
    """Randomly prune *model* until roughly ``rate`` of its prunable parameters are removed.

    Conv / Linear layers are drawn with probability proportional to their
    prunable-parameter count, measured once before pruning starts. The
    channels removed from the drawn layer come from ``self.select``. The last
    Conv/Linear layer is never a candidate.

    Args:
        model: network to prune in place.
        rate: target fraction of the parameters of all ``_PRUNABLE_MODULES``.
        example_inputs: tracing input; defaults to ``torch.randn(1, 3, 256, 256)``.

    Returns:
        The pruned ``model``. It is returned unchanged when there are no
        candidate layers (fewer than two Conv/Linear layers, or candidates
        without prunable parameters).
    """
    from bisect import bisect_right

    if example_inputs is None:
        example_inputs = torch.randn(1, 3, 256, 256)
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=example_inputs)

    prunable_layers = []
    total_params = 0
    # layer_starts[k] is the first global parameter index owned by prunable_layers[k].
    layer_starts = [0]
    for m in model.modules():
        if isinstance(m, _PRUNABLE_MODULES):
            nparam = tp.utils.count_prunable_params(m)
            total_params += nparam
            if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)):
                prunable_layers.append(m)
                layer_starts.append(layer_starts[-1] + nparam)

    # The final Conv/Linear layer is excluded from sampling. Without at least
    # one remaining candidate there is nothing to sample from.
    if len(prunable_layers) < 2:
        return model
    prunable_layers.pop(-1)
    layer_starts.pop(-1)
    num_conv_params = layer_starts[-1]
    if num_conv_params <= 0:
        return model

    num_pruned = 0
    target = total_params * rate
    # NOTE(review): if every candidate layer ends up with zero output
    # channels, the `continue` below makes this loop spin forever.
    while num_pruned < target:
        param_idx = random.randint(0, num_conv_params - 1)
        # Rightmost layer whose start <= param_idx owns that parameter.
        layer_to_prune = prunable_layers[bisect_right(layer_starts, param_idx) - 1]
        if layer_to_prune.weight.shape[0] < 1:
            continue
        idx = self.select(layer_to_prune)
        if isinstance(layer_to_prune, nn.modules.conv._ConvNd):
            fn = tp.prune_conv
        else:
            fn = tp.prune_linear
        plan = DG.get_pruning_plan(layer_to_prune, fn, idxs=idx)
        num_pruned += plan.exec()  # presumably the number of parameters removed
    return model
def prune_model(model,num_list):
    """Prune BatchNorm channels in every torchvision ResNet ``Bottleneck``.

    The k-th Bottleneck (in ``model.modules()`` order) loses ``num_list[2k]``
    channels from ``bn1`` and ``num_list[2k + 1]`` from ``bn2``. The channels
    with the smallest |gamma| (BN scale magnitude) are removed. Coupled
    layers follow through the dependency graph.

    Args:
        model: network to prune in place (moved to the module-level ``device``).
        num_list: number of channels to remove, two entries per Bottleneck.

    Returns:
        The same, pruned, ``model``.
    """
    model.to(device)
    # The tracing input must live on the same device as the model.
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224, device=device))

    def prune_bn(bn, num):
        # Rank by magnitude: a strongly negative scale is as important as a
        # strongly positive one, so the signed value is the wrong criterion.
        L1_norm = np.abs(bn.weight.detach().cpu().numpy())
        prune_index = np.argsort(L1_norm)[:num].tolist()  # remove channels with small |gamma|
        plan = DG.get_pruning_plan(bn, tp.prune_batchnorm, prune_index)
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.Bottleneck):
            prune_bn(m.bn1, num_list[blk_id])
            prune_bn(m.bn2, num_list[blk_id + 1])
            blk_id += 2
    return model
def prune_model(model):
    """L1-prune the two convolutions of every ``BasicBlock``.

    The k-th block (in ``model.modules()`` order) uses the ratio
    ``block_prune_probs[k]``. The schedule has 8 entries, presumably one per
    ResNet-18 basic block; a model with more blocks raises IndexError.
    The model is moved to the CPU, traced with a (1, 3, 32, 32) input and
    pruned in place.
    """
    model.cpu()
    graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        selected = tp.strategy.L1Strategy()(conv.weight, amount=amount)
        graph.get_pruning_plan(conv, tp.prune_conv, selected).exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blocks = [layer for layer in model.modules() if isinstance(layer, BasicBlock)]
    for position, block in enumerate(blocks):
        ratio = block_prune_probs[position]
        prune_conv(block.conv1, ratio)
        prune_conv(block.conv2, ratio)
    return model
def prune_model(model):
    """Prune BatchNorm channels inside every torchvision ResNet ``Bottleneck``.

    For the k-th Bottleneck, ``bn1`` and ``bn2`` each lose the fraction
    ``block_prune_probs[k]`` of their channels with the smallest |gamma|. The
    dependency graph propagates the removal to the coupled layers. The
    schedule index is clamped to its last entry, so models with more
    Bottlenecks than schedule entries reuse the final ratio.

    Args:
        model: network to prune in place; moved to the CPU and traced with a
            (1, 3, 224, 224) input.

    Returns:
        The same, pruned, ``model``.
    """
    model.cpu()
    DG = pruning.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224))

    def prune_bn(bn, pruned_prob):
        """Remove the ``pruned_prob`` fraction of *bn*'s channels with the smallest |gamma|."""
        gamma_abs = np.abs(bn.weight.detach().cpu().numpy())
        num_pruned = int(gamma_abs.shape[0] * pruned_prob)
        prune_index = np.argsort(gamma_abs)[:num_pruned].tolist()
        plan = DG.get_pruning_plan(bn, pruning.prune_batchnorm, prune_index)
        plan.exec()

    block_prune_probs = [0.8] * 16
    last_id = len(block_prune_probs) - 1
    blk_id = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.Bottleneck):
            ratio = block_prune_probs[min(blk_id, last_id)]
            prune_bn(m.bn1, ratio)
            prune_bn(m.bn2, ratio)
            blk_id += 1
    return model
def prune_model(model):
    """L1-prune the two convolutions of every ``resnet.BasicBlock``.

    The k-th block (in ``model.modules()`` order) removes the fraction
    ``block_prune_probs[k]`` of filters with the smallest L1 norm from both
    ``conv1`` and ``conv2``. The model is moved to the CPU, traced with a
    (1, 3, 32, 32) input and pruned in place.

    Returns:
        The same, pruned, ``model``.
    """
    model.cpu()
    DG = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 32, 32))

    def prune_conv(conv, pruned_prob):
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        # The L1 norm needs absolute values. A signed sum lets positive and
        # negative weights cancel and ranks strongly negative filters as "small".
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        num_pruned = int(out_channels * pruned_prob)
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()  # remove filters with small L1-Norm
        plan = DG.get_pruning_plan(conv, pruning.prune_conv, prune_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blk_id = 0
    for m in model.modules():
        if isinstance(m, resnet.BasicBlock):
            prune_conv(m.conv1, block_prune_probs[blk_id])
            prune_conv(m.conv2, block_prune_probs[blk_id])
            blk_id += 1
    return model
def prune_model_structured(model, device, amt):
    """L1-prune ``conv1``, ``conv2`` and ``conv3`` of every ``Bottleneck`` in ``model.module``.

    *model* must expose the real network as ``.module`` (presumably an
    ``nn.DataParallel``-style wrapper). Pruning is best effort: if any
    convolution of a block fails to prune, the rest of that block is skipped
    with a message and the next block is processed.

    Args:
        model: wrapper whose ``.module`` is pruned in place (traced on CPU
            with a (1, 3, 32, 32) input).
        device: device the model is moved to afterwards.
        amt: fraction of filters to remove from each convolution.

    Returns:
        ``model`` moved to *device*.
    """
    model.cpu()
    dummy_input = torch.randn(1, 3, 32, 32)
    DG = tp.DependencyGraph().build_dependency(model.module, dummy_input)

    def prune_conv(conv, amount=0.015):
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    for name, m in model.module.named_modules():
        if not isinstance(m, Bottleneck):
            continue
        try:
            prune_conv(m.conv1, amt)
            prune_conv(m.conv2, amt)
            prune_conv(m.conv3, amt)
        except Exception as exc:
            # Deliberately best effort, but report it instead of hiding it
            # (and no longer swallow KeyboardInterrupt / SystemExit).
            print(f"Skipping pruning of Bottleneck '{name}': {exc!r}")
    return model.to(device)
print("Average Inference time in ms:", average_inference_time)
inference_time_matrix.append(average_inference_time)

# prune.remove() only makes the pruning permanent: the masked weights are
# written into ``weight`` as zeros, so tensor shapes (and therefore the
# parameter count and inference time) stay the same. It is needed before
# the zeroed filters can be physically removed below.
for module in conv_modules:
    prune.remove(module, 'weight')
print("Number of params after removing mask:", get_num_params(conv_modules))

for module in conv_modules:
    # A filter counts as pruned only if all of its weights are zero. Sum
    # absolute values: a signed sum can also be zero for a live filter whose
    # weights cancel out.
    filter_abs_sum = module.weight.abs().sum(dim=(1, 2, 3)).detach().cpu().numpy()
    pruning_idxs = np.where(filter_abs_sum == 0)[0].tolist()
    # print(pruning_idxs)
    DG = pruning.DependencyGraph()
    DG.build_dependency(model_pruned, example_inputs=torch.randn(1, 3, 224, 224))
    pruning_plan = DG.get_pruning_plan(module, pruning.prune_conv, idxs=pruning_idxs)
    pruning_plan.exec()

print("==================== Actual pruned Model ====================")
num_params = get_num_params(conv_modules)
print("Number of params:", num_params)
num_params_list[3].append(num_params)
print("Model:")
print(model_pruned)
model_pruned.to(DEVICE)
def prune_model(model):
    '''Prune a MobileNetV2-style model with the Torch_pruning toolkit.

    Every ``mobilenetv2.Block`` has its ``conv1`` and ``conv2`` L1-pruned. The
    per-block pruning ratio comes from the global ``args``, with
    ``r = args.prune_rate``:

    * ``args.prune_method == 'constant'``: every block uses ``r``;
    * ``'linear'``: block k uses ``r * s_k``;
    * anything else (exponential): block k uses ``1 - (1 - r) ** s_k``;

    where the step ``s_k`` grows stepwise from 0 (first block) to 6 (last
    block). The ratio table has 17 entries, one per expected block; a model
    with more blocks raises IndexError. Prints the number of blocks pruned.

    Args:
        model: network containing ``mobilenetv2.Block`` modules; pruned in
            place (traced with a (1, 3, 32, 32) input).

    Returns:
        The same, pruned, ``model``.
    '''
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def plot_filters_single_channel(t, index):
        # One subplot per (kernel, input-channel) pair of the kernels in ``index``.
        nplots = t.shape[0] * t.shape[1]
        ncolumns = 12
        nrows = 1 + nplots // ncolumns
        count = 0
        fig = plt.figure(figsize=(ncolumns, nrows))
        for i in range(t.shape[0]):
            if i not in index:
                continue
            for j in range(t.shape[1]):
                count += 1
                ax1 = fig.add_subplot(nrows, ncolumns, count)
                npimage = np.array(t[i, j].numpy(), np.float32)
                # Standardize, then shift/clip into [0, 1] for display.
                npimage = (npimage - np.mean(npimage)) / np.std(npimage)
                npimage = np.minimum(1, np.maximum(0, (npimage + 0.5)))
                ax1.imshow(npimage)
                ax1.set_title(str(i) + ',' + str(j))
                ax1.axis('off')
                ax1.set_xticklabels([])
                ax1.set_yticklabels([])
        plt.tight_layout()
        plt.show()

    def plot_weights(layer, index, single_channel=True, collated=False):
        # Only convolution layers can be visualized.
        if isinstance(layer, nn.Conv2d):
            weight_tensor = layer.weight.data
            if single_channel:
                plot_filters_single_channel(weight_tensor, index)
        else:
            print("Can only visualize layers which are convolutional")

    def prune_conv(conv, amount=0.2):
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        # To inspect the filters about to be removed:
        # plot_weights(conv, pruning_index, single_channel=True)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    n = args.prune_rate
    steps = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6]
    if args.prune_method == 'constant':
        p_tuple = [n for _ in range(17)]
    elif args.prune_method == 'linear':
        p_tuple = [n * s for s in steps]
    else:
        # Exponential: the kept fraction (1 - n) ** s shrinks geometrically;
        # convert it into a pruning rate.
        keep = 1 - n
        p_tuple = [1 - keep ** s for s in steps]

    counter = 0
    for m in model.modules():
        if isinstance(m, mobilenetv2.Block):
            prune_conv(m.conv1, p_tuple[counter])
            prune_conv(m.conv2, p_tuple[counter])
            counter += 1
    print("Num of blocks: %d" % counter)
    return model
return nparams_to_prune # function wrapper def my_pruning_fn(layer: CustomizeLayer, idxs: list, inplace: bool = True, dry_run: bool = False): return MyPruningFn.apply(layer, idxs, inplace, dry_run) model = FullyConnectedNet(128, 10, 256) # pruning according to L1 Norm strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy() DG = tp.DependencyGraph() # Register your customized layer DG.register_customized_layer( CustomizeLayer, in_ch_pruning_fn=my_pruning_fn, # prune channels/dimensions of input tensor out_ch_pruning_fn= my_pruning_fn, # prune channels/dimensions of output tensor get_in_ch_fn=lambda l: l. in_dim, # estimate the n_channel of input tensor. You can return None if the layer does not change tensor shape. get_out_ch_fn=lambda l: l.in_dim ) # estimate the n_channel of output tensor. You can return None if the layer does not change tensor shape. # Build dependency graph DG.build_dependency(model, example_inputs=torch.randn(1, 128)) # get a pruning plan according to the dependency graph. idxs is the indices of pruned filters. pruning_plan = DG.get_pruning_plan(model.fc1,
def init_dependency_graph(self, model, input_w=32, input_h=32):
    """Build a dependency graph for *model* traced with a (1, 3, input_h, input_w) input.

    Args:
        model: network to trace. Falls back to ``self.model`` when ``None``.
        input_w: width of the dummy input.
        input_h: height of the dummy input.

    Returns:
        The ``pruning.DependencyGraph`` built for the model.
    """
    target = self.model if model is None else model
    # Put the dummy input on the model's own device, so CPU models work too.
    # Parameter-less models fall back to CUDA.
    first_param = next(target.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    # PyTorch image tensors are (N, C, H, W): height comes before width.
    fake_input = torch.randn(1, 3, input_h, input_w, device=device)
    DG = pruning.DependencyGraph(target, fake_input=fake_input)
    return DG
def prune_model(model):
    """Move *model* to the CPU and build its dependency graph with a (1, 3, 32, 32) input.

    NOTE(review): this definition looks truncated. The nested ``prune_conv``
    only reads the weight and its output-channel count and never builds or
    executes a pruning plan, and the function has no ``return`` (it returns
    None). Several other module-level ``prune_model`` definitions exist in
    this file; whichever is defined last wins.
    """
    model.cpu()
    DG = tp.DependencyGraph().build_dependency( model, torch.randn(1, 3, 32, 32) )
    def prune_conv(conv, pruned_prob):
        # Copy of the conv weight as a NumPy array; its first dimension is
        # the number of output channels.
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
import os
import sys

# Make the parent directory of this script importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)))

import torch
from torchvision.models import resnet18
import torch_pruning as pruning

model = resnet18(pretrained=True)

# Trace resnet18 once to record the dependencies between its layers.
DG = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 224, 224))

# Plan the removal of output filters 2, 6 and 9 of the stem convolution; the
# plan is derived from the dependency graph.
pruning_plan = DG.get_pruning_plan(model.conv1, pruning.prune_conv, idxs=[2, 6, 9])
print(pruning_plan)

# Apply the plan in place, then show the resulting architecture.
pruning_plan.exec()
print(model)
def channel_prune(ori_model, example_inputs, output_transform, pruned_prob=0.3, thres=None,
                  ignore_idx=(230, 260, 290)):
    """Threshold-based channel pruning on a deep copy of *ori_model*.

    Every ``nn.BatchNorm2d`` is considered, except those whose position in
    ``model.modules()`` is listed in *ignore_idx*. Each one drops the
    channels whose |weight| (BN scale) is below a global threshold. A layer
    always keeps at least its strongest channel. BN statistics are plotted
    before and after via ``bn_analyze``.

    Args:
        ori_model: network to copy; it is not modified.
        example_inputs: input used for tracing and for a sanity forward pass.
        output_transform: optional callable applied to the model output; may be ``None``.
        pruned_prob: when *thres* is ``None``, the threshold is the value at
            position ``int(pruned_prob * len(bn_val))`` of the values returned
            by ``bn_analyze`` (presumably sorted |gamma| values).
        thres: explicit threshold; overrides *pruned_prob*.
        ignore_idx: ``model.modules()`` positions that are never pruned.

    Returns:
        The pruned copy of the model.
    """
    model = copy.deepcopy(ori_model)
    model.cpu().eval()
    prunable_module_type = (nn.BatchNorm2d,)
    ignore_idx = set(ignore_idx)
    prunable_modules = [
        m for i, m in enumerate(model.modules())
        if i not in ignore_idx and isinstance(m, prunable_module_type)
    ]
    ori_size = tp.utils.count_params(model)
    DG = tp.DependencyGraph().build_dependency(
        model, example_inputs=example_inputs, output_transform=output_transform)
    bn_val, max_val = bn_analyze(prunable_modules, "render_img/before_pruning.jpg")
    if thres is None:
        thres_pos = int(pruned_prob * len(bn_val))
        thres_pos = max(min(thres_pos, len(bn_val) - 1), 0)
        thres = bn_val[thres_pos]
    print("Min val is %f, Max val is %f, Thres is %f" % (bn_val[0], bn_val[-1], thres))

    for layer_to_prune in prunable_modules:
        weight = layer_to_prune.weight.data.detach().cpu().numpy()
        if isinstance(layer_to_prune, nn.Conv2d):
            prune_fn = tp.prune_group_conv if layer_to_prune.groups > 1 else tp.prune_conv
            L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        elif isinstance(layer_to_prune, nn.BatchNorm2d):
            prune_fn = tp.prune_batchnorm
            L1_norm = np.abs(weight)
        prun_index = np.flatnonzero(L1_norm < thres).tolist()
        # Never empty a layer completely: keep its strongest channel.
        if len(prun_index) == len(L1_norm):
            del prun_index[np.argmax(L1_norm)]
        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)
        plan.exec()

    bn_analyze(prunable_modules, "render_img/after_pruning.jpg")
    with torch.no_grad():
        out = model(example_inputs)
        if output_transform:
            out = output_transform(out)
    print(" Params: %s => %s" % (ori_size, tp.utils.count_params(model)))
    outputs = out if isinstance(out, (list, tuple)) else [out]
    for o in outputs:
        print(" Output: ", o.shape)
    print("------------------------------------------------------\n")
    return model