def prune_actor(self, ratios, dims):
    """Apply L1-norm structured pruning to the actor layers at index 0 and 2.

    ``self.actor[0]`` is pruned with ``ratios[0]`` along ``dims[0]`` and
    ``self.actor[2]`` with ``ratios[1]`` along ``dims[1]``.

    Args:
        ratios: two pruning amounts, one per pruned layer.
        dims: two dimensions along which whole channels are removed.

    Raises:
        ValueError: if ``ratios`` or ``dims`` does not hold exactly two entries.
    """
    if len(ratios) != 2:
        raise ValueError("length of ratios not matching actor number of layers")
    if len(dims) != 2:
        raise ValueError("length of dims not matching actor number of layers")
    # Each pruned layer uses its own (ratio, dim) pair; the second layer used
    # to reuse dims[0] by mistake.
    for layer_idx, amount, dim in zip((0, 2), ratios, dims):
        prune.ln_structured(self.actor[layer_idx], "weight", amount=amount, n=1, dim=dim)
def prune_step(model, a1=0.01, a2=0.01, conv_group=True):
    """Run one pruning step over every sub-module of ``model``.

    With ``conv_group`` enabled, each ``Conv2d`` gets L2 structured pruning
    by ``a1`` along dim 0 and then along dim 1. Otherwise ``Conv2d`` layers,
    like ``Linear`` layers, get L1 unstructured pruning by ``a2``.
    """
    for _, layer in model.named_modules():
        is_conv = isinstance(layer, torch.nn.Conv2d)
        if is_conv and conv_group:
            for axis in (0, 1):
                prune.ln_structured(layer, name='weight', amount=a1, n=2, dim=axis)
            continue
        if is_conv or isinstance(layer, torch.nn.Linear):
            prune.l1_unstructured(layer, name='weight', amount=a2)
def prune_module(module, method, amount):
    """Prune ``module.weight`` by ``amount`` using the named method.

    ``"ln"`` selects L2 structured pruning along dim 0, ``"l1"`` selects
    L1 unstructured pruning.

    Raises:
        ValueError: for any other ``method``.
    """
    if method not in ("ln", "l1"):
        raise ValueError(f"{method} is wrong")
    if method == "l1":
        prune.l1_unstructured(module, name="weight", amount=amount)
        return
    prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
def prune_transformer_block(transformer_block, args):
    """Structurally prune the linear layers of a transformer block in place.

    ``fc1``, ``fc2`` and every ``torch.nn.Linear`` found in ``fc_delta`` and
    ``fc_gamma`` are pruned along dim 0 with ``n=0`` by
    ``float(args.pruning_amount)``. The pruning is made permanent right away
    with ``prune.remove``. The same block object is returned.
    """
    amount = float(args.pruning_amount)

    def _prune_permanently(layer):
        # Mask the weight, then bake the mask into the parameter.
        prune.ln_structured(layer, name='weight', amount=amount, n=0, dim=0)
        prune.remove(layer, 'weight')

    _prune_permanently(transformer_block.fc1)
    _prune_permanently(transformer_block.fc2)
    for container_name in ('fc_delta', 'fc_gamma'):
        for sub_module in getattr(transformer_block, container_name):
            if isinstance(sub_module, torch.nn.Linear):
                _prune_permanently(sub_module)
    return transformer_block
def prune(self):
    """Apply L1-norm structured pruning (dim 0) to the block's convolutions.

    ``conv1``, ``conv2`` and ``conv3`` are pruned by ``self.sparsity``. When
    ``self.se`` is set, every ``nn.Conv2d`` in ``self.se.se`` is pruned the
    same way.
    """
    # Inside the body, ``prune`` resolves to the module-level pruning module,
    # not to this method (method names live in the class namespace).
    for conv in (self.conv1, self.conv2, self.conv3):
        prune.ln_structured(conv, name='weight', amount=self.sparsity, n=1, dim=0)
    if self.se is not None:
        for child in self.se.se:
            if isinstance(child, nn.Conv2d):
                prune.ln_structured(child, name='weight', amount=self.sparsity, n=1, dim=0)
def prune_all(self):
    """Prune every ``conv{i}``/``bn{i}`` pair of ``self.model`` and its ``fc1``.

    For ``i`` in ``1..self.nb_layers`` the convolution gets Ln structured
    pruning (``self.norme`` norm along ``self.dim``) and the batch-norm gets
    L1 unstructured pruning, both by ``self.amount``. ``fc1`` is then pruned
    with the same structured settings.
    """
    for layer_no in range(1, self.nb_layers + 1):
        # Plain attribute lookup replaces the former eval() on a built string.
        conv = getattr(self.model, f"conv{layer_no}")
        bn = getattr(self.model, f"bn{layer_no}")
        prune.ln_structured(module=conv, name='weight', amount=self.amount,
                            n=self.norme, dim=self.dim)
        prune.l1_unstructured(module=bn, name='weight', amount=self.amount)
    prune.ln_structured(module=self.model.fc1, name='weight', amount=self.amount,
                        n=self.norme, dim=self.dim)
def prune_all(self):
    """Prune the stem, the three residual stages and the classifier.

    ``conv1`` and ``linear`` get Ln structured pruning (``self.norme`` norm
    along ``self.dim``), ``bn1`` gets L1 unstructured pruning, and
    ``layer1``..``layer3`` are handed to ``self.prune_block``. Every call
    uses ``self.amount``.
    """
    net = self.model

    def _structured(layer):
        prune.ln_structured(module=layer, name='weight', amount=self.amount,
                            n=self.norme, dim=self.dim)

    _structured(net.conv1)
    prune.l1_unstructured(module=net.bn1, name='weight', amount=self.amount)
    for stage_name in ('layer1', 'layer2', 'layer3'):
        self.prune_block(getattr(net, stage_name))
    _structured(net.linear)
def prune_model(model):
    """Prune the module-level ``conv_layers`` by one incremental step.

    The step size is ``total_prune_amount / (max_epochs * 10)``. When
    ``prune_type`` is ``'global_unstructured'`` all conv weights are pruned
    jointly with L1 unstructured pruning; otherwise each layer gets L1-norm
    structured pruning along dim 0. The ``model`` argument is not used by the
    body; the layers come from the ``conv_layers`` global.
    """
    remove_amount = total_prune_amount / (max_epochs * 10)
    print(f'pruned model by {remove_amount}')

    if prune_type != 'global_unstructured':
        for conv in conv_layers:
            prune.ln_structured(conv, name='weight', amount=remove_amount, n=1, dim=0)
        return

    targets = []
    for conv in conv_layers:
        targets.append((conv, 'weight'))
    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=remove_amount,
    )
def channel_pruning(old_model, pruning_option="l1", name="weight", prune_ratio=0.5):
    """Collect, per ``Conv2d`` layer, the indices of output channels kept.

    ``"l1"`` / ``"l2"`` prune each convolution in place with structured
    pruning of the matching norm along dim 0 and read the surviving channels
    through ``alive_channel_index``. ``"random"`` leaves the model untouched
    and samples the surviving indices with ``random.sample``. Any other
    option yields an empty list.

    Returns:
        A list with one entry of alive-channel indices per convolution.
    """
    alive_weight_index = []

    if pruning_option in ("l1", "l2"):
        # Alive channels are derived from Ln structured pruning.
        norm_order = 1 if pruning_option == "l1" else 2
        for _, layer in old_model.named_modules():
            if not isinstance(layer, torch.nn.Conv2d):
                continue
            prune.ln_structured(layer, name=name, amount=prune_ratio, n=norm_order, dim=0)
            alive_weight_index.append(alive_channel_index(layer))
    elif pruning_option == "random":
        # Alive channels are drawn at random; no mask is attached.
        for _, layer in old_model.named_modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            total_channels = layer.weight.data.shape[0]
            survivors = total_channels - int(prune_ratio * total_channels)
            alive_weight_index.append(random.sample(range(total_channels), survivors))

    return alive_weight_index
def get_pruned_model(model, normed=False, amount=0.5):
    """Return a deep copy of ``model`` with L2 structured pruning applied.

    Every ``nn.Conv2d`` and every ``nn.Linear`` except the one at index
    ``len(model) - 2`` is pruned along dim 0 by ``amount``. The input model
    is left untouched.

    Args:
        model: an indexable container of layers (indexed by position).
        normed: when true, the surviving weights are divided by
            ``1 - amount``.
        amount: pruning amount forwarded to ``prune.ln_structured``.

    Returns:
        The pruned copy.
    """
    model_pruned = deepcopy(model)
    skipped_linear_idx = len(model) - 2
    for idx in range(len(model)):
        layer = model_pruned[idx]
        is_conv = isinstance(layer, nn.Conv2d)
        is_linear = isinstance(layer, nn.Linear) and idx != skipped_linear_idx
        if not (is_conv or is_linear):
            continue
        prune.ln_structured(layer, name='weight', amount=amount, n=2, dim=0)
        if normed:
            # The pruning forward pre-hook rebuilds ``weight`` from
            # ``weight_orig * weight_mask`` on every forward pass, so the
            # rescaling has to be applied to ``weight_orig``; rescaling only
            # the ``weight`` attribute is lost at the first forward call.
            layer.weight_orig.data.div_(1 - amount)
            layer.weight = layer.weight_orig * layer.weight_mask
    return model_pruned
def prune_by_percentile(self, amount=5.0):
    """Prune the dense-layer modules using a percentile-derived amount.

    The ``amount``-th percentile of the absolute values of all non-zero,
    non-bias, non-mask parameters is computed and logged. That value is then
    passed as the pruning amount of an L2 structured pruning (dim 0) on every
    sub-module whose qualified name contains ``'denselayer'`` and which owns
    a weight tensor with at least two dimensions.

    Args:
        amount: percentile (0-100) used to derive the threshold.
    """
    alive_parameters = []
    for name, p in self.named_parameters():
        if 'bias' in name or 'mask' in name:
            continue
        tensor = p.data.cpu().numpy()
        alive_parameters.append(tensor[np.nonzero(tensor)])

    all_alives = np.concatenate(alive_parameters)
    percentile_value = np.percentile(np.abs(all_alives), amount)
    logging.info(f'Pruning with threshold : {percentile_value}')

    # NOTE(review): ``percentile_value`` is a weight magnitude, yet it is
    # handed to ``ln_structured`` as ``amount`` (a fraction of channels).
    # Kept as in the original; confirm this is the intended meaning.
    for name, module in self.named_modules():
        # ``str`` has no ``.contains``; substring membership uses ``in``.
        if 'denselayer' not in name:
            continue
        weight = getattr(module, 'weight', None)
        # Containers and 1-D weights cannot take structured pruning.
        if weight is None or weight.dim() < 2:
            continue
        prune.ln_structured(module, name='weight', amount=float(percentile_value), n=2, dim=0)
def pruning_cp_fg(net, a_list):
    """Return a pruned deep copy of ``net``.

    ``a_list`` is consumed in ``named_modules`` order, one entry per
    ``Conv2d`` / ``Linear`` layer; the pruning amount of a layer is
    ``1 - a_list[i]``. Convolutions get L2 structured pruning along dim 0,
    linear layers get L1 unstructured pruning.

    Returns:
        The pruned copy, or ``None`` (after printing a message) when ``net``
        is not an ``nn.Module``.
    """
    if not isinstance(net, nn.Module):
        print('Invalid input. Must be nn.Module')
        return

    newnet = copy.deepcopy(net)
    cursor = 0
    for _, layer in newnet.named_modules():
        if isinstance(layer, nn.Conv2d):
            prune.ln_structured(layer, name='weight', amount=float(1 - a_list[cursor]),
                                n=2, dim=0)
        elif isinstance(layer, nn.Linear):
            prune.l1_unstructured(layer, name='weight', amount=float(1 - a_list[cursor]))
        else:
            continue
        cursor += 1
    return newnet
def prune_block(self, sub_layer):
    """Prune every block of a residual stage.

    For each block, ``conv1`` / ``conv2`` get Ln structured pruning
    (``self.norme`` norm along ``self.dim``) and ``bn1`` / ``bn2`` get L1
    unstructured pruning. ``Conv1d`` and ``BatchNorm1d`` layers found in the
    block's ``shortcut`` are pruned the same way. All calls use
    ``self.amount``.
    """
    def _structured(layer):
        prune.ln_structured(module=layer, name='weight', amount=self.amount,
                            n=self.norme, dim=self.dim)

    def _unstructured(layer):
        prune.l1_unstructured(module=layer, name='weight', amount=self.amount)

    for block in sub_layer:
        _structured(block.conv1)
        _unstructured(block.bn1)
        _structured(block.conv2)
        _unstructured(block.bn2)
        for short_layer in block.shortcut:
            if isinstance(short_layer, torch.nn.Conv1d):
                _structured(short_layer)
            elif isinstance(short_layer, torch.nn.BatchNorm1d):
                _unstructured(short_layer)
def channel_pruning(net, a_list):
    """Return a deep copy of ``net`` whose convolutions carry pruning masks.

    :param net: DNN
    :param a_list: one entry per ``Conv2d`` in ``named_modules`` order; the
        pruning amount of a layer is ``1 - a_list[i]``
    :return: newnet (nn.Module): a copy with L2 structured masks (dim 0) on
        every convolution weight, or ``None`` (after printing a message)
        when ``net`` is not an ``nn.Module``
    """
    if not isinstance(net, nn.Module):
        print('Invalid input. Must be nn.Module')
        return

    newnet = copy.deepcopy(net)
    conv_layers = (m for _, m in newnet.named_modules() if isinstance(m, nn.Conv2d))
    for position, conv in enumerate(conv_layers):
        prune.ln_structured(conv, name='weight', amount=float(1 - a_list[position]),
                            n=2, dim=0)
    return newnet
def main():
    """Load a detector, optionally prune it repeatedly, and time inference.

    The detector is built from hard-coded config/checkpoint paths on
    ``cuda:0``. Ten times in a row, when the module-level ``PRUNE`` flag is
    set, every ``Conv2d`` / ``Linear`` found up to three levels below the
    backbone, ROI head and RPN head gets structured pruning along dim 0 with
    ``n=-inf`` (5% for direct children, 15% for deeper ones), then one
    inference on the source image is timed. Finally the last result is shown.
    """
    config = '/media/shalev/98a3e66d-f664-402a-9639-15ec6b8a7150/work_dirs/try2/faster_rcnn_r50_caffe_c4_1x_coco_shalev.py'
    checkpoint = '/media/shalev/98a3e66d-f664-402a-9639-15ec6b8a7150/work_dirs/try2/latest.pth'
    src_img_path = '/home/shalev/downloads/1pic_coco/000000000285.jpg'
    dst_img_path = '/home/shalev/downloads/1pic_coco/000000000285_res.jpg'

    img = cv2.imread(src_img_path)
    model = mmdet.apis.init_detector(config, checkpoint=checkpoint, device='cuda:0')

    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)

    def _prune_and_report(layer, amount):
        # Print the weight sum before and after so the effect is visible.
        print("before: ", layer.weight.sum())
        prune.ln_structured(layer, name='weight', amount=amount, dim=0, n=float('-inf'))
        print("after: ", layer.weight.sum())

    for _ in range(10):
        if PRUNE:
            groups = [
                model.backbone.children(),
                model.roi_head.children(),
                model.rpn_head.children(),
            ]
            for group in groups:
                for module in group:
                    if isinstance(module, prunable_types):
                        _prune_and_report(module, 0.05)
                        continue
                    for sub in module.children():
                        if isinstance(sub, prunable_types):
                            _prune_and_report(sub, 0.15)
                            continue
                        for sub_sub in sub.children():
                            if isinstance(sub_sub, prunable_types):
                                _prune_and_report(sub_sub, 0.15)

        start = time.time()
        res = mmdet.apis.inference_detector(model, img)
        print("Inference time: ", (time.time() - start))

    if hasattr(model, 'module'):
        model = model.module
    img_res = model.show_result(img, res, score_thr=0.305, show=True)
def prune_critic(self, ratios, dims):
    """Apply structured pruning to the three critic layers.

    ``self.critic[0]`` and ``self.critic[2]`` are pruned with the L1 norm,
    ``self.critic_linear`` with the L2 norm; layer ``i`` uses ``ratios[i]``
    along ``dims[i]``.

    Raises:
        ValueError: if ``ratios`` or ``dims`` does not hold exactly three
            entries.
    """
    if len(ratios) != 3:
        raise ValueError("length of ratios not matching critic number of layers")
    if len(dims) != 3:
        # The message used to blame ``ratios`` for a ``dims`` mismatch.
        raise ValueError("length of dims not matching critic number of layers")
    targets = (
        (self.critic[0], 1),
        (self.critic[2], 1),
        (self.critic_linear, 2),
    )
    for (layer, norm_order), amount, dim in zip(targets, ratios, dims):
        prune.ln_structured(layer, "weight", amount=amount, n=norm_order, dim=dim)
def _prune_res_unit5(self, ratio=0.1): prune.ln_structured(list(self.conv5_x)[0].conv1, name="weight", amount=ratio, n=1, dim=0) prune.ln_structured(list(self.conv5_x)[0].conv2, name="weight", amount=ratio, n=1, dim=0) prune.ln_structured(list(self.conv5_x)[1].conv1, name="weight", amount=ratio, n=1, dim=0) prune.ln_structured(list(self.conv5_x)[1].conv2, name="weight", amount=ratio, n=1, dim=0)
weight_bit_num=weight_bit_width) # end = time.time() # print(f'It takes {end-start:.6f} seconds.') net_gpu.load_state_dict( torch.load('weight_bit_' + str(weight_bit_width) + '_best.pth')) net_gpu.to(device) resume_acc = evaluate(net_gpu, xtest_gpu, ytest) print('quantization best accuracy: {:.5f}'.format(resume_acc)) #prune conv_module = net_gpu.conv2 prune.ln_structured(conv_module, name='weight', amount=amount_num, n=2, dim=0) # print(list(model.conv2.named_parameters())) # prune_acc = evaluate(net_gpu, xtest_gpu, ytest) print('prune accuracy: {:.5f}'.format(prune_acc)) fine_tune_epoch = 20 train_and_eval(xtrain_gpu, ytrain_gpu, net_gpu, xtest_gpu, ytest,
def pruning(model0, percentage, method):
    """Prune ``model0`` in place, retrain it, evaluate it and quantize it.

    Steps:
      1. Prune every ``Embedding`` / ``Linear`` weight and the four LSTM
         weight matrices (``weight_{hh,ih}_l{0,1}``) by ``percentage``,
         with L1 unstructured pruning or L1-norm structured pruning (dim 0).
      2. Make the pruning permanent with ``prune.remove``.
      3. Retrain for ``num_epochs_retrain`` epochs while zeroing the
         gradient of every (near-)zero weight so pruned weights stay pruned.
      4. Evaluate on ``X`` / ``Y`` after each epoch and once more at the end.
      5. Dynamically quantize the model to int8.

    The data, loaders, loss, device and hyper-parameters are read from
    module-level names (``X``, ``Y``, ``train_loader``, ``criterion``,
    ``lr``, ``batch_size``, ...).

    Args:
        model0: the model to prune; it is modified in place.
        percentage: pruning amount forwarded to the pruning functions.
        method: ``"unstructured"`` or ``"structured"``.

    Returns:
        ``(model_prune, quantized_model_prune)``.

    Raises:
        ValueError: if ``method`` is not one of the two supported names.
    """
    if method == "unstructured":
        def apply_pruning(module, name):
            prune.l1_unstructured(module, name=name, amount=percentage)
    elif method == "structured":
        def apply_pruning(module, name):
            prune.ln_structured(module, name=name, amount=percentage, n=1, dim=0)
    else:
        raise ValueError(
            "unknown pruning method {!r}; expected 'unstructured' or 'structured'".format(method))

    # Assumes a two-layer, unidirectional LSTM, as the original code did.
    lstm_weight_names = ('weight_hh_l0', 'weight_ih_l0', 'weight_hh_l1', 'weight_ih_l1')

    # Collect (module, parameter name) pairs once; they are pruned and then
    # un-reparametrized in the same order.
    targets = []
    for _, module in model0.named_modules():
        if isinstance(module, (torch.nn.Embedding, torch.nn.Linear)):
            targets.append((module, 'weight'))
        elif isinstance(module, torch.nn.LSTM):
            targets.extend((module, weight_name) for weight_name in lstm_weight_names)
    for module, name in targets:
        apply_pruning(module, name)
    for module, name in targets:
        prune.remove(module, name)

    def build_test_loader():
        test_data = TensorDataset(torch.Tensor(X), torch.Tensor(Y).long())
        return DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

    def evaluate(model, loader):
        """Return the accuracy (in percent) over at most 101 test batches."""
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(loader):
                x = x.reshape(-1, time_steps, input_size).to(device)
                y = y.to(device)
                outputs = model(x)
                # Explicit dim avoids the implicit-dim deprecation warning.
                prob = F.softmax(outputs, dim=1)
                _, predicted = torch.max(prob, 1)
                total += y.size(0)
                correct += (predicted == y).sum().item()
                if batch_idx >= 100:
                    break
        accuracy_pct = 100 * correct / total
        print('Test Accuracy of the model on the 10000 test x: {} %'.format(accuracy_pct))
        return accuracy_pct

    test_loader = build_test_loader()
    model_prune = model0
    optimizer = torch.optim.Adam(model_prune.parameters(), lr)
    model_ls = []
    accuracy = []
    for epoch in range(num_epochs_retrain):
        for i, (x, y) in enumerate(train_loader):
            x = x.reshape(-1, time_steps, input_size).to(device)
            y = y.to(device)

            # forward pass
            outputs = model_prune(x)
            loss = criterion(outputs, y)

            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            for name, param in model_prune.named_parameters():
                if "weight" in name and param.grad is not None:
                    # Freeze pruned weights only. The mask must look at the
                    # magnitude: comparing the signed value against the
                    # threshold also froze every negative weight.
                    param.grad.data.masked_fill_(param.data.abs() < 0.00001, 0)
            optimizer.step()

            if i % 1000 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, num_epochs_retrain, i + 1, total_step, loss.item()))

        epoch_accuracy = evaluate(model_prune, test_loader)
        # NOTE(review): the same model object is appended every epoch, so the
        # argmax selection below cannot recover an earlier epoch's weights.
        model_ls.append(model_prune)
        accuracy.append(epoch_accuracy)

    if accuracy:  # nothing to choose from when no retraining epoch ran
        model_prune = model_ls[np.argmax(accuracy)]

    # test accuracy
    evaluate(model_prune, build_test_loader())

    # quantize the pruned model
    quantized_model_prune = torch.quantization.quantize_dynamic(
        model_prune.to('cpu'), {nn.Embedding, nn.LSTM, nn.Linear}, dtype=torch.qint8)
    return model_prune, quantized_model_prune
def prune_model(method_name, parameters_to_prune, pruning_rate):
    """Prune the given ``(module, parameter_name)`` pairs with one method.

    Supported ``method_name`` values:
      * ``'l1_unstructured'``: global L1 unstructured pruning.
      * ``'random'``: global random unstructured pruning.
      * ``'l1_structured'`` / ``'l2_structured'``: per-parameter structured
        pruning along the last dimension with the L1 / L2 norm, applied
        three times in a row with ``pruning_rate`` each time.

    Raises:
        ValueError: if ``method_name`` is not supported.
    """
    global_methods = {
        'l1_unstructured': prune.L1Unstructured,
        'random': prune.RandomUnstructured,
    }
    # 'l1_structured' used n=2 like the L2 variant; it now uses the L1 norm.
    structured_norms = {'l1_structured': 1, 'l2_structured': 2}

    if method_name in global_methods:
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=global_methods[method_name],
            amount=pruning_rate,
        )
    elif method_name in structured_norms:
        norm_order = structured_norms[method_name]
        for module, name in parameters_to_prune:
            for _ in range(3):
                prune.ln_structured(module=module, name=name, n=norm_order,
                                    amount=pruning_rate, dim=-1)
    else:
        # ``raise ("...")`` raised TypeError because a str is not an exception.
        raise ValueError("Pruning method not found")
print("# conv1 pruned buffers") print(list(module.named_buffers())) # Prune weight using L1 norm and 3 smallest entries prune.l1_unstructured(module, name="bias", amount=3) print("# conv1 pruned bias params") print(list(module.named_parameters())) print("# conv1 pruned bias buffers") print(list(module.named_buffers())) print("# Forward pre hooks") print(module._forward_pre_hooks) # Iterative pruning (Prune multiple times in series, zeros out 50%) prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0) # weights pruned print(module.weight) for hook in module._forward_pre_hooks.values(): if hook._tensor_name == "weight": # select out the correct hook break # pruning history print(list(hook)) # Remove reparamatrization prune.remove(module, "weight") print(list(module.named_parameters())) # Prune multiple based on type (20% in conv and 40% in linear)
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train, prune and optionally fine-tune one generator per pyramid scale.

    For every scale from 0 to ``opt.stop_scale`` (inclusive) a fresh G/D pair
    is created, trained with ``train_single_scale``, pruned (global L1
    unstructured or per-conv L1-norm structured, depending on the
    module-level ``structured`` flag), optionally fine-tuned when the
    module-level ``fine_tune`` flag is set, and finally appended to ``Gs``.
    ``Zs``, the pruned ``Gs``, ``reals`` and ``NoiseAmp`` are saved under
    ``opt.out_`` after every scale.

    ``Gs``, ``Zs`` and ``NoiseAmp`` are extended in place; ``opt`` receives
    several derived attributes (``nfc``, ``min_nfc``, ``out_``, ``outf``).
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarsest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0

    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            # Any OSError (e.g. the folder already exists) is ignored.
            pass

        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr,G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of CNN block might differ. (every 4 levels according to the paper)
        # When the channel count is unchanged, warm-start from the previous
        # scale's saved weights.
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))

        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        # With fine-tuning, the pre-pruning phase only runs for warmup_steps.
        if fine_tune:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)

        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()

        #################################################################################
        # Visualize weights
        def visualize_weights(modules, fig_name):
            # Concatenate all weights, print the fraction of exact zeros and
            # the element count, then save a histogram of the non-zero values
            # to "<opt.outf>/<fig_name>.png".
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning weights Structured or Non-structured
        if not structured:
            modules = [G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv, G_curr.body.block1.norm, G_curr.body.block2.conv, G_curr.body.block2.norm, G_curr.body.block3.conv, G_curr.body.block3.norm, G_curr.tail[0]]
            # Both the weight and the bias of every listed module take part
            # in the global pruning.
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )
            visualize_weights(modules, 'ori')
            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            # Structured mode only touches the conv weights of head and body.
            modules = [G_curr.head.conv, G_curr.body.block1.conv, G_curr.body.block2.conv, G_curr.body.block3.conv]
            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)
            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)

        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')

        # Generate a sample with the freshly pruned generator; shallow copies
        # keep the caller's lists untouched at this point.
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)

        # Fine-tuning
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()
            if not structured:
                # Keep training using inherited weights
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Training from scratch
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)

            G_curr = functions.reset_grads(G_curr,False)
            G_curr.eval()
            visualize_weights(modules, 'fine-tune')

        # Make the pruning permanent; biases were only pruned in the
        # unstructured case.
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)
        #################################################################################

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
Module.append(newModel.features[i]) frac = length - count pr.append(maxpr / frac) with open(outfile, 'a') as f: f.write("Global Sparsity: {:.2f}%".format(pr[count] * 100)) count += 1 """Select the amount of feature we want to prune in each Layer""" epochs = 10 max_lr = 1e-3 grad_clip = .2 weight_decay = 1e-5 L1 = 1e-5 itteration = 5 for ittr in range(itteration): for i in range(len(Module)): prune.ln_structured(Module[i], name="weight", amount=pr[i], n=1, dim=0) numberOfZero = 0 numberOfElements = 0 totalNumberOfZero = 0 totalNumberOfElements = 0 for i, j in zip(range(len(prunelist)), prunelist): numberOfZero = torch.sum(Module[i].weight == 0) totalNumberOfZero += numberOfZero numberOfElements = Module[i].weight.nelement() totalNumberOfElements += numberOfElements frac = 100. * float(torch.sum(Module[i].weight == 0)) / float( Module[i].weight.nelement()) with open(outfile, 'a') as f: f.write(f"\n {j} Sparsity in {Module[i]} is \t{frac}") with open(outfile, 'a') as f:
def prune_actor(self, ratios, dims):
    """Prune the policy head, then delegate the remaining layers to the base.

    The last entry of ``ratios`` / ``dims`` drives an L1-norm structured
    pruning of ``self.dist.linear``; all preceding entries are forwarded to
    ``self.base.prune_actor``.
    """
    head_ratio, head_dim = ratios[-1], dims[-1]
    # if type(self.dist) is Categorical:
    prune.ln_structured(self.dist.linear, "weight", amount=head_ratio, n=1, dim=head_dim)
    # if type(self.dist) is DiagGaussian:
    #     prune.ln_structured(self.dist.fc_mean, "weight", amount=ratios[-1], n=1, dim=dims[-1])
    self.base.prune_actor(ratios[:-1], dims[:-1])
def __init__(self, in_planes, planes, num_bits, num_bits_weight, stride,
             type_prune, sparsity, layer_num, option='A'):
    """Build one stage of the network, selected by ``in_planes``.

    * ``in_planes == 3``: quantized 3x3 conv + batch-norm + ReLU.
    * ``in_planes == 1``: average pooling + flatten + 64->10 linear layer.
    * otherwise: a single ``BasicBlock``.

    The conv / linear layer is pruned according to ``type_prune``:
    ``'channel'`` applies L2 structured pruning along dim 0 by ``sparsity``,
    ``'group'`` applies a custom mask built from groups of four weights;
    any other value leaves the layer unpruned.
    """
    super(ResNetBlock, self).__init__()

    def _group_mask(weight):
        # Boolean mask keeping the width-4 weight groups whose RMS value is
        # at or above the threshold that retains a (1 - sparsity) fraction.
        width = 4
        scores = weight.data.clone()
        original_size = scores.size()
        scores = scores.view(original_size[0], -1)
        # NOTE(review): when the row length is already a multiple of
        # ``width`` this still appends a full extra group; kept as is.
        append_size = width - scores.shape[1] % width
        scores = torch.cat((scores, scores[:, 0:append_size]), 1)
        scores = scores.view(scores.shape[0], -1, width)
        scores = scores.pow(2.0).mean(2, keepdim=True).pow(0.5).expand(scores.shape)
        flat = scores.flatten()
        num = flat.shape[0] * (1 - sparsity)
        top_k = torch.topk(flat, int(num), sorted=True)
        threshold = top_k.values[-1]
        mask = scores.ge(threshold)
        mask = mask.view(original_size[0], -1)
        # Drop the padding columns again before restoring the weight shape.
        mask = mask[:, 0:weight.data[0].nelement()]
        return mask.contiguous().view(original_size)

    def _sparsify(layer):
        # Attach the requested pruning reparametrization to ``layer``.
        if type_prune == 'channel':
            return prune.ln_structured(layer, name='weight', amount=sparsity, n=2, dim=0)
        if type_prune == 'group':
            return prune.custom_from_mask(layer, name='weight',
                                          mask=_group_mask(layer.weight))
        return layer

    if in_planes == 3:
        op = QConv2d(in_planes, planes, num_bits, num_bits_weight,
                     kernel_size=3, stride=1, padding=1, bias=False)
        self.add_module("conv", _sparsify(op))
        self.add_module("bn", nn.BatchNorm2d(planes))
        self.add_module("relu", nn.ReLU(inplace=True))
    elif in_planes == 1:
        self.add_module("avg_pool", nn.AvgPool2d(kernel_size=8, stride=1))
        self.add_module("flatten", Flatten())
        op = nn.Linear(in_features=64, out_features=10)
        self.add_module("fc", _sparsify(op))
    else:
        op = BasicBlock(in_planes, planes, num_bits, num_bits_weight, stride,
                        type_prune, sparsity, layer_num, option)
        self.add_module("conv", op)
def global_pruning(self, p_to_delete, dim=0):
    """Apply L1-norm structured pruning to every module in ``target_modules``.

    Args:
        p_to_delete: pruning amount forwarded to ``prune.ln_structured``.
        dim: dimension along which whole channels are removed.
    """
    # dim is where the weights get removed (row: 1, col: 0?).
    # Open question from the author: which dim is better to prune along?
    for layer in self.target_modules:
        prune.ln_structured(layer, name="weight", amount=p_to_delete, n=1, dim=dim)
def __init__(self, in_planes, planes, num_bits, num_bits_weight, stride,
             type_prune, sparsity, layer_num, option='A'):
    """Build a residual block of two quantized 3x3 convolutions.

    Each convolution is followed by a batch-norm and is pruned according to
    ``type_prune``: ``'channel'`` applies L2 structured pruning along dim 0
    by ``sparsity``, ``'group'`` applies a custom mask built from groups of
    four weights; any other value leaves it unpruned.

    When ``stride != 1`` or the channel count changes, the shortcut is
    either a strided, zero-padded identity (``option == 'A'``) or a 1x1
    quantized convolution followed by batch-norm (``option == 'B'``).
    """
    super(BasicBlock, self).__init__()

    def _group_mask(weight):
        # Boolean mask keeping the width-4 weight groups whose RMS value is
        # at or above the threshold that retains a (1 - sparsity) fraction.
        width = 4
        scores = weight.data.clone()
        original_size = scores.size()
        scores = scores.view(original_size[0], -1)
        # NOTE(review): when the row length is already a multiple of
        # ``width`` this still appends a full extra group; kept as is.
        append_size = width - scores.shape[1] % width
        scores = torch.cat((scores, scores[:, 0:append_size]), 1)
        scores = scores.view(scores.shape[0], -1, width)
        scores = scores.pow(2.0).mean(2, keepdim=True).pow(0.5).expand(scores.shape)
        flat = scores.flatten()
        num = flat.shape[0] * (1 - sparsity)
        top_k = torch.topk(flat, int(num), sorted=True)
        threshold = top_k.values[-1]
        mask = scores.ge(threshold)
        mask = mask.view(original_size[0], -1)
        # Drop the padding columns again before restoring the weight shape.
        mask = mask[:, 0:weight.data[0].nelement()]
        return mask.contiguous().view(original_size)

    def _sparsify(layer):
        # Attach the requested pruning reparametrization to ``layer``.
        if type_prune == 'channel':
            return prune.ln_structured(layer, name='weight', amount=sparsity, n=2, dim=0)
        if type_prune == 'group':
            return prune.custom_from_mask(layer, name='weight',
                                          mask=_group_mask(layer.weight))
        return layer

    self.conv1 = _sparsify(QConv2d(in_planes, planes, num_bits, num_bits_weight,
                                   kernel_size=3, stride=stride, padding=1, bias=False))
    self.bn1 = nn.BatchNorm2d(planes)

    self.conv2 = _sparsify(QConv2d(planes, planes, num_bits, num_bits_weight,
                                   kernel_size=3, stride=1, padding=1, bias=False))
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != planes:
        if option == 'A':
            # For CIFAR10 ResNet paper uses option A.
            self.shortcut = LambdaLayer(lambda x: F.pad(
                x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4),
                "constant", 0))
        elif option == 'B':
            self.shortcut = nn.Sequential(
                QConv2d(in_planes, self.expansion * planes, num_bits,
                        num_bits_weight, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
def __init__(self, args, config):
    """Set up an experiment: outputs, data, device, model, loss, optimizer.

    Reads the run options from ``args`` and ``config``, seeds the RNGs,
    builds the data loaders for the configured dataset, creates the model
    (optionally resuming from ``args.resume`` and pruning its convolutions
    when ``args.pruned_retrain`` is set), and restores optimizer/scheduler
    state when resuming a training run.

    Raises:
        Exception: if ``config.dataset`` is not a supported dataset.
    """
    # Init arguments
    self.args = args
    self.config = config
    self.config.experiment_start_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    self.config.experiment_name = self.experiment_name
    if not self.args.dry_run:
        self.checkpoints_dir_path = os.path.join(self.output_dir_path, 'checkpoints')
        self.setup_experiment_output()
    self.logger = Logger(args, config, self.output_dir_path)
    self.logger.log_config()

    # Randomness
    random.seed(config.random_seed)
    torch.manual_seed(config.random_seed)
    torch.cuda.manual_seed_all(config.random_seed)

    # Datasets
    if config.dataset == 'CIFAR10' or config.dataset == 'CIFAR100':
        self.input_size = 32
        train_loader, test_loader, num_classes = cifar_dataloader.get_loaders(config.dataset, args.datadir, args.batch_size, args.num_workers)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.num_classes = num_classes
    elif config.dataset == 'GTSRB':
        self.input_size = 32
        train_loader, val_loader, test_loader, num_classes = gtsrb_dataloader.cf_gtsrb.get_loaders(args.datadir, args.batch_size, args.num_workers)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.num_classes = num_classes
    else:
        raise Exception("Dataset not supported: {}".format(config.dataset))

    # Init starting values
    self.starting_epoch = 1
    self.best_val_acc = 0

    # Setup device
    # ``args.gpus`` arrives as a comma-separated string and is replaced by a
    # list of ints; the first id becomes the main device.
    if self.args.gpus is not None:
        self.args.gpus = [int(i) for i in self.args.gpus.split(',')]
        self.device = 'cuda:' + str(args.gpus[0])
        torch.backends.cudnn.benchmark = True
    else:
        self.device = 'cpu'
    self.device = torch.device(self.device)

    # Setup model
    model = get_model(self.config, self.num_classes, self.input_size)

    # Resume model, if any
    if args.resume:
        print('Loading model checkpoint at: {}'.format(args.resume))
        package = torch.load(args.resume, map_location=self.device)
        model_state_dict = package['state_dict']
        #model_state_dict = utils.state_dict_retrocompatibility(model_state_dict)
        model.load_state_dict(model_state_dict, strict=args.strict)

    if args.pruned_retrain:
        for name, module in model.named_modules():
            # L1-norm structured pruning (amount=0.7, dim 0) on all 2D-conv layers
            if isinstance(module, torch.nn.Conv2d):
                prune.ln_structured(module, name='weight', amount=0.7, n=1, dim=0)

    self.model = model.to(device=self.device)
    if self.args.gpus is not None and len(self.args.gpus) > 1:
        self.model = nn.DataParallel(self.model, self.args.gpus)

    #Loss function
    self.criterion = nn.CrossEntropyLoss()
    self.criterion = self.criterion.to(device=self.device)

    # Init optimizer
    # presumably ``optimizer`` is a property whose setter builds the
    # optimizer from the model -- confirm against the class definition
    self.optimizer = self.model  # setter syntax

    # Resume optimizer, if any
    if args.resume and not args.evaluate and not args.pruned_retrain:
        self.logger.log.info("Loading optimizer checkpoint")
        if 'optim_dict' in package.keys():
            self.optimizer.load_state_dict(package['optim_dict'])
        if 'epoch' in package.keys():
            self.starting_epoch = package['epoch']

    # LR scheduler
    # presumably ``scheduler`` is a property whose setter builds the
    # scheduler from the optimizer -- confirm against the class definition
    self.scheduler = self.optimizer  # setter syntax

    # Resume scheduler, if any
    if args.resume \
            and not args.evaluate \
            and self.scheduler is not None and 'epoch' in package.keys():
        self.scheduler.last_epoch = package['epoch'] - 1

    # Recap
    self.logger.log_cmd_args()
def _prune_res_unit1(self, ratio=0.1):
    """Apply structured L1 pruning to the first layer of ``self.conv1``.

    Whole slices along dim 0 of that layer's ``weight`` are masked (the
    output-channel axis for a convolution weight).

    Args:
        ratio: Forwarded unchanged as ``amount`` to
            ``prune.ln_structured``.
    """
    layers = list(self.conv1)
    first_layer = layers[0]
    prune.ln_structured(first_layer, name="weight", amount=ratio, n=1, dim=0)
def prune_model(model, prune_protopyte):
    """Transfer the weights of ``model`` into a smaller prototype network.

    Both networks are deep-copied, so neither argument is modified. The
    two copies are walked in lock-step (``named_modules`` order, root
    skipped) and are expected to list corresponding layers at the same
    positions under the same names.

    For every ``Conv2d`` / ``Linear`` layer whose weight is larger than
    the prototype's along dim 0 and/or dim 1, the surplus slices with the
    smallest L1 norm are masked with ``prune.ln_structured``. The
    surviving entries are then copied into the prototype. A
    ``BatchNorm2d`` layer whose size differs is reduced with the
    surviving dim-0 indices of the tracked layer immediately before it.
    Layers of matching size are copied as they are.

    Prototype state entries that are not explicitly copied (anything
    other than weight, bias and the BatchNorm running statistics of the
    tracked layers) are left at zero.

    Args:
        model: Source network holding the trained weights.
        prune_protopyte: Network with the target (smaller) layer sizes.
            The misspelled name is kept for interface compatibility.

    Returns:
        A deep copy of ``prune_protopyte`` loaded with the selected
        weights of ``model``.

    Raises:
        ValueError: If the first tracked layer is a ``BatchNorm2d``, so
            there is no preceding layer to take channel indices from.
        KeyError: If a resized ``BatchNorm2d`` follows a layer that was
            not pruned in this call.
    """
    model = copy.deepcopy(model)
    prune_protopyte = copy.deepcopy(prune_protopyte)

    # Surviving dim-0 (output) indices of every layer pruned below.
    kept_outputs = {}

    # Step 1: mask the surplus channels of the source copy.
    paired = zip(model.named_modules(), prune_protopyte.named_modules())
    for idx, ((name, module), (_, target)) in enumerate(paired):
        if idx == 0:
            continue  # first entry is the root module itself
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            w_diff = torch.abs(torch.tensor(module.weight.shape)
                               - torch.tensor(target.weight.shape))
            surplus_out = int(w_diff[0].item())
            surplus_in = int(w_diff[1].item())
            if surplus_out > 0:
                prune.ln_structured(module, name="weight",
                                    amount=surplus_out, n=1, dim=0)
            if surplus_in > 0:
                prune.ln_structured(module, name="weight",
                                    amount=surplus_in, n=1, dim=1)
            if surplus_out > 0 or surplus_in > 0:
                # Rows that still hold at least one unmasked entry.
                rows = torch.where(module.weight_mask != 0)[0]
                kept_outputs[name] = torch.unique(rows)
        elif isinstance(module, nn.BatchNorm2d):
            w_diff = torch.abs(torch.tensor(module.weight.shape)
                               - torch.tensor(target.weight.shape))
            if w_diff[0] > 0:
                # Only done so that weight_orig / weight_mask appear in the
                # state dict and flag this layer as resized in step 3; the
                # mask itself is never read.
                prune.l1_unstructured(module, name="weight", amount=1.0)

    # Step 2: ordered inventory of the layers whose state is transferred.
    layer_kinds = {}
    for idx, (name, module) in enumerate(model.named_modules()):
        if idx == 0:
            continue
        if isinstance(module, nn.Conv2d):
            layer_kinds[name] = 'Conv2d'
        elif isinstance(module, nn.BatchNorm2d):
            layer_kinds[name] = 'BatchNorm2d'
        elif isinstance(module, nn.Linear):
            layer_kinds[name] = 'Linear'

    # Each BatchNorm2d takes its channel selection from the tracked layer
    # right before it.
    ordered_names = list(layer_kinds)
    bn_dependencies = {}
    for idx, name in enumerate(ordered_names):
        if layer_kinds[name] != 'BatchNorm2d':
            continue
        if idx == 0:
            raise ValueError(
                f"BatchNorm2d layer '{name}' is the first tracked layer; "
                "it has no preceding Conv2d/Linear layer to depend on")
        bn_dependencies[name] = ordered_names[idx - 1]

    # Step 3: build the prototype's state dict.
    target_state = prune_protopyte.state_dict()
    for tensor in target_state.values():
        tensor.fill_(0)
    # Built once: the source copy is not modified from here on.
    source_state = model.state_dict()

    for layer, kind in layer_kinds.items():
        weight_key = f'{layer}.weight'
        bias_key = f'{layer}.bias'
        was_pruned = (f'{layer}.weight_orig' in source_state
                      and f'{layer}.weight_mask' in source_state)

        if not was_pruned:
            # Same size in both networks: plain copy.
            target_state[weight_key] = source_state[weight_key]
            if kind == 'BatchNorm2d':
                for stat in ('running_mean', 'running_var'):
                    stat_key = f'{layer}.{stat}'
                    target_state[stat_key] = source_state[stat_key]
            if bias_key in source_state:
                target_state[bias_key] = source_state[bias_key]
            continue

        weights = source_state[f'{layer}.weight_orig']
        if kind == 'BatchNorm2d':
            keep = kept_outputs[bn_dependencies[layer]]
            target_state[weight_key] = weights[keep].reshape(
                target_state[weight_key].shape)
            for stat in ('running_mean', 'running_var'):
                stat_key = f'{layer}.{stat}'
                target_state[stat_key] = source_state[stat_key][keep].reshape(
                    target_state[stat_key].shape)
        else:
            keep = kept_outputs[layer]
            mask = source_state[f'{layer}.weight_mask']
            # Boolean indexing flattens in row-major order, so the kept
            # entries fall directly into the prototype's smaller shape.
            target_state[weight_key] = weights[mask.bool()].reshape(
                target_state[weight_key].shape)
        if bias_key in source_state:
            target_state[bias_key] = source_state[bias_key][keep].reshape(
                target_state[bias_key].shape)

    prune_protopyte.load_state_dict(target_state)
    return prune_protopyte