def ntk_prune_adv(args, nets, data_loader, num_classes, samples_per_class=1):
    # Score masks by the gradient of the NTK Gram-matrix norm, accumulated over
    # several random (Kaiming) re-initializations, then prune to args.sparse_lvl.
    for net in nets:
        net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
            snip.add_criterion(layer)
    model = nets[0]
    for _ in range(10):
        utils.kaiming_initialize(model)
        # data_iter = iter(snip_loader)
        ntk = get_ntk_n(data_loader, [model], train_mode=True, num_batch=1,
                        num_classes=num_classes, samples_per_class=samples_per_class)
        # torch.svd(ntk[0])[1].sum().backward()
        torch.norm(ntk[0]).backward()
        snip.update_criterion(model)
        snip.weight_mask_grad_zero(model)
    for net in nets:
        snip.net_prune_advsnip(net, args.sparse_lvl)
    print('[*] NTK pruning done!')

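# -----------------------------------------------------------------------------
# Illustration only: get_ntk_n is defined elsewhere in this repo, so the sketch
# below shows just the general finite-width NTK computation it is based on --
# stack per-sample parameter gradients into a Jacobian J and form the Gram
# matrix J @ J.T.  The helper name and the toy model are placeholders, not part
# of the pruner.
def _ntk_gram_sketch(model, inputs):
    grads = []
    for x in inputs:
        model.zero_grad()
        model(x.unsqueeze(0)).sum().backward()   # scalar output per sample
        grads.append(torch.cat([p.grad.reshape(-1) for p in model.parameters()]))
    J = torch.stack(grads)                       # [n_samples, n_params]
    return J @ J.t()                             # finite-width NTK Gram matrix
# Example (commented out so the module stays side-effect free):
# tiny = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
# print(_ntk_gram_sketch(tiny, torch.randn(4, 8)).shape)   # torch.Size([4, 4])
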
def apply_specprune(nets, sparsity):
    # Prune each Linear/Conv layer with specprune using the per-layer ratio in
    # `sparsity`; a value of 1 leaves the layer dense.
    applied_layer = 0
    for net in nets:
        for layer in net.modules():
            if isinstance(layer, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                if sparsity[applied_layer] != 1:
                    # add_mask_rand_basedonchannel(layer, sparsity[applied_layer], ratio, True, structured)
                    specprune(layer, sparsity[applied_layer])
                    print('[*] Layer ' + str(applied_layer) + ' pruned!')
                else:
                    snip.add_mask_ones(layer)
                applied_layer += 1
        deactivate_mask_update(net)

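# -----------------------------------------------------------------------------
# Usage sketch (hypothetical values): `sparsity` appears to hold one keep ratio
# per prunable (Linear/Conv2d/ConvTranspose2d) layer in module-traversal order.
# The toy net and ratios below are placeholders; specprune and
# deactivate_mask_update come from elsewhere in this repo.
# toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
# apply_specprune([toy], sparsity=[0.5, 1])   # prune first conv to 50%, keep second dense
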
def apply_svip_givensparsity(args, nets, sparsity):
    # first add masks to each layer of nets
    for net in nets:
        net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
    model = nets[0]
    num_iter = 10
    # if args.iter_prune:
    #     num_iter = round(math.log(args.sparse_lvl, 0.8))
    #     for i in range(num_iter):
    #         loss = get_svip_loss(model)
    #         loss.backward()
    #         # prune the network using CS
    #         for net in nets:
    #             net_prune_svip(net, 0.8**(i+1))
    applied_layer = 0
    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            # print('Total param:', layer.weight.numel())
            # if layer.weight.numel() * (1-sparsity[applied_layer]) > 1000:
            #     num_iter = 10
            # else:
            #     num_iter = 1
            if args.iter_prune:
                for i in range(num_iter):
                    loss = get_svip_loss(layer)
                    loss.backward()
                    # net_prune_svip_layer(layer, sparsity[applied_layer]**((i+1)/num_iter))
                    net_prune_svip_layer_inv(
                        layer, sparsity[applied_layer]**((i + 1) / num_iter))
                    # print('Actual:', (layer.weight_mask.sum()/layer.weight_mask.numel()).item())
                    # print('Expected:', sparsity[applied_layer])
            else:
                loss = get_svip_loss(layer)
                loss.backward()
                net_prune_svip_layer(layer, sparsity[applied_layer])
            applied_layer += 1
    deactivate_mask_update(net)
    print('[*] SVIP pruning done!')

def apply_svip(args, nets):
    # first add masks to each layer of nets
    for net in nets:
        net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
    model = nets[0]
    # if args.iter_prune:
    #     num_iter = round(math.log(args.sparse_lvl, 0.8))
    #     for i in range(num_iter):
    #         loss = get_svip_loss(model)
    #         loss.backward()
    #         # prune the network using CS
    #         for net in nets:
    #             net_prune_svip(net, 0.8**(i+1))
    if args.iter_prune:
        num_iter = 100
        for i in range(num_iter):
            loss = get_svip_loss(model)
            loss.backward()
            # prune the network using CS
            for net in nets:
                net_prune_svip(net, args.sparse_lvl**((i + 1) / num_iter))
                # svip_reinit(net)
            if i % 10 == 0:
                print('Prune ' + str(i) + ' iterations.')
    else:
        loss = get_svip_loss(model)
        loss.backward()
        # prune the network using CS
        for net in nets:
            net_prune_svip(net, args.sparse_lvl)
            deactivate_mask_update(net)
    print('[*] SVIP pruning done!')

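# -----------------------------------------------------------------------------
# Note on the iterative branch above: keeping sparse_lvl ** ((i + 1) / num_iter)
# of the weights at step i walks the density geometrically from near-dense down
# to the final value sparse_lvl.  A purely illustrative helper:
def _density_schedule_sketch(sparse_lvl=0.05, num_iter=5):
    # e.g. 0.05 over 5 steps -> roughly [0.55, 0.30, 0.17, 0.09, 0.05]
    return [sparse_lvl ** ((i + 1) / num_iter) for i in range(num_iter)]
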
def apply_rand_prune_givensparsity_var(nets, sparsity, ratio, structured, args):
    # Randomly prune each Linear/Conv layer to the per-layer keep ratio in
    # `sparsity` using a channel-based shuffled mask; a value of 1 leaves the
    # layer dense.  Masks are applied to the weights in place.
    applied_layer = 0
    with torch.no_grad():
        for net in nets:
            for layer in net.modules():
                if isinstance(layer, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                    if sparsity[applied_layer] != 1:
                        add_mask_rand_basedonchannel_shuffle(
                            layer, sparsity[applied_layer], args.shuffle_ratio)
                        # add_mask_rand_basedonchannel(layer, sparsity[applied_layer], ratio, True, structured)
                        # snip.add_mask_rand(layer, sparsity[applied_layer], True, structured)
                    else:
                        snip.add_mask_ones(layer)
                    layer.weight *= layer.weight_mask
                    applied_layer += 1
            deactivate_mask_update(net)

def ntk_ep_prune(args, nets, data_loader, num_classes, samples_per_class=1):
    print('[*] Using NTK+EP pruning.')
    print('[*] Coefficient used is {}'.format(args.ep_coe))
    for net in nets:
        net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
    model = nets[0]
    if args.iter_prune:
        num_iter = 10
    else:
        num_iter = 1
    for i in range(num_iter):
        # data_iter = iter(snip_loader)
        ntk = get_ntk_n(data_loader, [model], train_mode=True, num_batch=1,
                        num_classes=num_classes, samples_per_class=samples_per_class)
        # ntk = get_ntk_n(data_loader, [model], train_mode=True, num_batch=1)
        # torch.svd(ntk[0])[1].sum().backward()
        ntk_loss = torch.norm(ntk[0])
        ep_loss = get_ep(model)
        (ntk_loss + args.ep_coe / ep_loss).backward()
        snip.net_iterative_prune(model, args.sparse_lvl**((i + 1) / num_iter))
        snip.weight_mask_grad_zero(model)
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            weight_check = module.weight
            print(((weight_check != 0).float().sum() / weight_check.numel()))
    for net in nets:
        net.zero_grad()
        net.train()
    print('[*] NTK+EP pruning done!')

def get_zenscore(model, data_loader, arch, num_classes):
    # Clear stats of BN and make BN layer eval mode
    if arch == 'resnet18':
        # print('[*] Creating resnet18 clone without skip connection.')
        # model_clone = ResNet_wo_skip(BasicBlock_wo_skip, [2, 2, 2, 2]).cuda()
        # model_clone.fc = nn.Linear(512, num_classes).cuda()
        # for (layer_clone, layer) in zip(model_clone.modules(), model.modules()):
        #     if isinstance(layer, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        #         if hasattr(layer, 'weight_mask'):
        #             snip.add_mask_ones(layer_clone)
        #             layer_clone.weight_mask.data = layer.weight_mask.data
        # print('[*] Creating resnet18 clone without skip connection.')
        model_clone = ResNet_with_gap(BasicBlock, [2, 2, 2, 2]).cuda()
        model_clone.fc = nn.Linear(512, num_classes).cuda()
        for (layer_clone, layer) in zip(model_clone.modules(), model.modules()):
            if isinstance(layer, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                if hasattr(layer, 'weight_mask'):
                    snip.add_mask_ones(layer_clone)
                    layer_clone.weight_mask.data = layer.weight_mask.data
    else:
        model_clone = copy.deepcopy(model).cuda()
    for module in model_clone.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.running_mean.fill_(0)
            module.running_var.fill_(1)
    model_clone.train()
    data_iter = iter(data_loader)
    input_size = next(data_iter)[0].shape
    # Calculate Zen-Score
    nx = 10
    ne = 10
    GAP_zen = []
    output_zen = []
    with torch.no_grad():
        for _ in range(nx):
            # for module in model_clone.modules():
            #     if hasattr(module, 'weight') and isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            #         module.weight.data = torch.randn(module.weight.size(), device=module.weight.device)
            for _ in range(ne):
                input = torch.empty(input_size)
                nn.init.normal_(input)
                noise = torch.empty(input_size)
                nn.init.normal_(noise)
                input = input.cuda()
                noise = noise.cuda()
                GAP_feature = model_clone.GAP(input)
                GAP_feature_perturb = model_clone.GAP(input + 0.01 * noise)
                output = model_clone(input)
                output_perturb = model_clone(input + 0.01 * noise)
                GAP_zen.append(torch.norm(GAP_feature_perturb - GAP_feature).item())
                output_zen.append(torch.norm(output_perturb - output).item())
    var_prod = 1.0
    for layer in model_clone.modules():
        if isinstance(layer, nn.BatchNorm2d):
            var = layer.running_var
            var_prod += np.log(np.sqrt((var.sum() / len(var)).item()))
    print('[*] Product of variances is: {}'.format(var_prod))
    print('[*] Original Zen are: {},{}'.format(np.mean(GAP_zen), np.mean(output_zen)))
    GAP_zen = np.log(np.mean(GAP_zen)) + var_prod
    output_zen = np.log(np.mean(output_zen)) + var_prod
    del model_clone
    return [GAP_zen, output_zen]

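# -----------------------------------------------------------------------------
# For reference, a compact generic Zen-score sketch in the spirit of Zen-NAS:
# expected feature perturbation under Gaussian inputs, plus the sum of log
# BatchNorm running stds.  It is not a drop-in replacement for get_zenscore
# above (no ResNet clone, no GAP head); it assumes BN statistics were reset and
# the model maps a Gaussian batch straight to a feature/logit tensor.
def _zen_score_sketch(model, input_size, repeats=10, eps=0.01):
    device = next(model.parameters()).device
    model.train()
    diffs = []
    with torch.no_grad():
        for _ in range(repeats):
            x = torch.randn(input_size, device=device)
            noise = torch.randn_like(x)
            diffs.append(torch.norm(model(x + eps * noise) - model(x)).item())
    log_bn_std = 0.0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            log_bn_std += np.log(np.sqrt(m.running_var.mean().item()))
    return np.log(np.mean(diffs)) + log_bn_std
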
def apply_zenprune(args, nets, data_loader):
    print('[*] Zen-Prune starts.')
    for net in nets:
        # net.eval()
        net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
            # if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear) or isinstance(layer, nn.ConvTranspose2d):
            #     nn.init.normal_(layer.weight)
    model = nets[0]
    data_iter = iter(data_loader)
    imagesize = next(data_iter)[0].shape
    if args.iter_prune:
        num_iter = 100
    else:
        num_iter = 1
    n_x = 10
    n_eta = 10
    eta = 0.01
    for i in range(num_iter):
        # Zero out gradients for weight_mask so as to start a new round of iterative pruning
        model.zero_grad()
        for layer in model.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                layer.weight_mask.grad = None
            if isinstance(layer, nn.BatchNorm2d):
                layer.running_mean.fill_(0)
                layer.running_var.fill_(1)
        for _ in range(n_x):
            # initialize weights drawn from Normal distribution N(0,1)
            # with torch.no_grad():
            #     for module in model.modules():
            #         if hasattr(module, 'weight') and isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            #             module.weight.data = torch.randn(module.weight.size(), device=module.weight.device)
            # Taking expectation w.r.t eta
            for _ in range(n_eta):
                input = torch.empty(imagesize)
                nn.init.normal_(input)
                noise = torch.empty(imagesize)
                nn.init.normal_(noise)
                input = input.cuda()
                noise = noise.cuda()
                output = model(input)
                output_perturb = model(input + 0.01 * noise)
                zen_score = torch.norm(output - output_perturb)
                zen_score.backward()
        # snip.prune_net_increaseloss(model, args.sparse_lvl**((i+1)/num_iter))
        # snip.net_iterative_prune(model, args.sparse_lvl**((i+1)/num_iter))
        snip.net_prune_grasp(model, args.sparse_lvl**((i + 1) / num_iter))
        if i % 5 == 0:
            print('Prune ' + str(i) + ' iterations.')
            print('Zen-score is {}'.format(zen_score.item()))
    snip.deactivate_mask_update(model)
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            weight_check = module.weight_mask
            print(((weight_check != 0).float().sum() / weight_check.numel()))
    print('-' * 20)
    # # remove hooks
    # for handle in handles:
    #     handle.remove()
    # zero out gradients of weights
    for net in nets:
        net.zero_grad()
        net.train()

def apply_SAP(args, nets, data_loader, criterion, num_classes, samples_per_class=10):
    print('[*] Currently using SAP pruning.')
    for net in nets:
        # net.eval()
        net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
    model = nets[0]
    data_iter = iter(data_loader)
    imagesize = next(data_iter)[0].shape
    if args.iter_prune:
        num_iter = 100
    else:
        num_iter = 1
    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            layer.base_weight = layer.weight.detach()
    n_x = 10
    n_eta = 10
    eta = 0.01
    for i in range(num_iter):
        # Taking expectation w.r.t x
        model.zero_grad()
        for layer in model.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                layer.weight_mask.grad = None
        for _ in range(n_x):
            try:
                (input, target) = snip.GraSP_fetch_data(data_iter, num_classes, samples_per_class)
            except:
                data_iter = iter(data_loader)
                (input, target) = snip.GraSP_fetch_data(data_iter, num_classes, samples_per_class)
            target_var = target.cuda()
            input_var = input.cuda()
            # Taking expectation w.r.t eta
            for _ in range(n_eta):
                with torch.no_grad():
                    for layer in model.modules():
                        if isinstance(layer, (nn.Conv2d, nn.Linear)):
                            layer.weight.data = layer.base_weight + eta * torch.randn(
                                layer.weight.size(), device=layer.weight.device)
                # compute output
                output = model(input_var)
                loss = criterion(output, target_var)
                loss.backward()
                ##################################################
                # This part is for adversarial perturbations
                ##################################################
                # output = model(input_var)
                # loss = criterion(output, target_var)
                # loss.backward()
                # with torch.no_grad():
                #     for layer in model.modules():
                #         if isinstance(layer, (nn.Conv2d, nn.Linear)):
                #             layer.weight_mask.grad = None
                #             layer.weight.data = layer.base_weight + eta*layer.weight.grad/torch.norm(eta*layer.weight.grad)
                # # compute output
                # output = model(input_var)
                # loss = criterion(output, target_var)
                # loss.backward()
        with torch.no_grad():
            for layer in model.modules():
                if isinstance(layer, (nn.Conv2d, nn.Linear)):
                    layer.weight.data = layer.base_weight
        snip.net_iterative_prune(model, args.sparse_lvl**((i + 1) / num_iter))
        # snip.prune_net_decreaseloss(model, args.sparse_lvl**((i+1)/num_iter), True)
        if i % 10 == 0:
            print('Prune ' + str(i) + ' iterations')
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            weight_check = module.weight
            print(((weight_check != 0).float().sum() / weight_check.numel()))
    print('-' * 20)
    for net in nets:
        net.zero_grad()
        net.train()

def apply_nsprune(args, nets, data_loader, num_classes, samples_per_class=10, GAP=True):
    # Prune by noise sensitivity: perturb the input with Gaussian noise and score
    # masks via the gradient of ||f(x + noise) - f(x)|| / ||x||.
    print('Using GAP is {}'.format(GAP))
    for net in nets:
        net.eval()
        # net.train()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
    model = nets[0]
    data_iter = iter(data_loader)
    imagesize = next(data_iter)[0].shape
    if args.iter_prune:
        num_iter = 100
    else:
        num_iter = 1
    n_x = 10
    n_eta = 10
    eta = 0.01
    for i in range(num_iter):
        # Taking expectation w.r.t x
        model.zero_grad()
        for layer in model.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                layer.weight_mask.grad = None
        ns_tracking = 0
        for _ in range(n_x):
            try:
                (input, target) = snip.GraSP_fetch_data(data_iter, num_classes, samples_per_class)
            except:
                data_iter = iter(data_loader)
                (input, target) = snip.GraSP_fetch_data(data_iter, num_classes, samples_per_class)
            input = input.cuda()
            norm_x = torch.norm(input)
            # Taking expectation w.r.t eta
            for _ in range(n_eta):
                # noise = torch.ones(input.size())*torch.randn(1,)*eta*norm_x
                noise = torch.randn(input.size()) * eta * norm_x
                input_perturb = input + noise.cuda()
                if GAP:
                    output = model.GAP(input)
                    output_perturb = model.GAP(input_perturb)
                else:
                    output = model(input)
                    output_perturb = model(input_perturb)
                perturbation = torch.norm(output - output_perturb) / norm_x
                perturbation.backward()
                ns_tracking += perturbation.item()
        snip.net_iterative_prune_wolinear(model, args.sparse_lvl**((i + 1) / num_iter))
        # snip.prune_net_decreaseloss(model, args.sparse_lvl**((i+1)/num_iter), True)
        if i % 10 == 0:
            print('Prune ' + str(i) + ' iterations, noise sensitivity:{}'.format(ns_tracking))
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            weight_check = module.weight
            print(((weight_check != 0).float().sum() / weight_check.numel()))
    print('-' * 20)
    for net in nets:
        net.zero_grad()
        net.train()

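# -----------------------------------------------------------------------------
# For reference, the per-sample quantity accumulated above is essentially
# E_noise ||f(x + noise) - f(x)|| / ||x|| with the noise scaled to eta * ||x||.
# A self-contained sketch of that metric alone (no masks, no pruning); the
# helper name is a placeholder:
def _noise_sensitivity_sketch(model, x, eta=0.01, repeats=10):
    norm_x = torch.norm(x)
    vals = []
    with torch.no_grad():
        for _ in range(repeats):
            noise = torch.randn_like(x) * eta * norm_x
            vals.append((torch.norm(model(x + noise) - model(x)) / norm_x).item())
    return sum(vals) / len(vals)
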
def apply_synflow(args, nets, data_loader):
    @torch.no_grad()
    def linearize(model):
        signs = {}
        for name, param in model.state_dict().items():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs

    @torch.no_grad()
    def nonlinearize(model, signs):
        for name, param in model.state_dict().items():
            param.mul_(signs[name])

    print('[*] Synflow pruner starts.')
    for net in nets:
        # To use Synflow, we have to set Batchnorm to eval mode
        net.eval()
        net.zero_grad()
        for layer in net.modules():
            snip.add_mask_ones(layer)
    model = nets[0]
    # Reset batchnorm statistics and clear any stale weight_mask gradients
    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            layer.weight_mask.grad = None
        if isinstance(layer, nn.BatchNorm2d):
            layer.running_mean.fill_(0)
            layer.running_var.fill_(1)
    data_iter = iter(data_loader)
    (data, _) = next(data_iter)
    input_dim = list(data[0, :].shape)
    input = torch.ones([1] + input_dim).to(device)  # , dtype=torch.float64).to(device)
    if args.iter_prune:
        num_iter = 100
    else:
        num_iter = 1
    signs = linearize(model)
    for i in range(num_iter):
        # Zero out gradients for weight_mask so as to start a new round of iterative pruning
        model.zero_grad()
        output = model(input)
        torch.sum(output).backward()
        # snip.prune_net_increaseloss(model, args.sparse_lvl**((i+1)/num_iter))
        # snip.net_iterative_prune(model, args.sparse_lvl**((i+1)/num_iter))
        snip.net_prune_grasp(model, args.sparse_lvl**((i + 1) / num_iter))
        if i % 5 == 0:
            print('Prune ' + str(i) + ' iterations.')
    snip.deactivate_mask_update(model)
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            weight_check = module.weight_mask
            print(((weight_check != 0).float().sum() / weight_check.numel()))
    print('-' * 20)
    # # remove hooks
    # for handle in handles:
    #     handle.remove()
    # zero out gradients of weights
    for net in nets:
        net.zero_grad()
        net.train()
        # net.reinit()
    nonlinearize(model, signs)

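# -----------------------------------------------------------------------------
# For reference, the classic SynFlow saliency underlying the masked pruner
# above is |w * dR/dw|, where R is the summed output of the sign-linearized
# (|w|) network fed an all-ones input.  A generic, self-contained sketch, not
# this repo's weight_mask-based implementation; the helper name is a
# placeholder:
def _synflow_scores_sketch(model, input_shape):
    device = next(model.parameters()).device
    signs = {n: p.sign() for n, p in model.state_dict().items()}
    with torch.no_grad():
        for p in model.state_dict().values():
            p.abs_()                                  # linearize: w <- |w|
    model.eval()
    model.zero_grad()
    out = model(torch.ones((1,) + tuple(input_shape), device=device))
    out.sum().backward()
    scores = {n: (p.grad * p).abs().detach()          # |w * dR/dw|
              for n, p in model.named_parameters() if p.grad is not None}
    with torch.no_grad():
        for n, p in model.state_dict().items():
            p.mul_(signs[n])                          # restore original signs
    return scores
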