def create_identity_mask(model, tensor_name="weight"):
    # Attach an all-ones (identity) pruning mask to every module that exposes
    # the requested tensor; modules without it are skipped (exception printed).
    for _name, module in model.named_modules():
        try:
            prune.identity(module, tensor_name)
        except Exception as e:
            print(e)
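For reference, a minimal, self-contained sketch (illustrative only, not one of the collected examples) of what prune.identity attaches to a module: the original tensor is renamed to weight_orig and an all-ones weight_mask buffer is registered, so the effective weight is initially unchanged.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
prune.identity(layer, name="weight")
print([name for name, _ in layer.named_buffers()])                    # ['weight_mask']
print(torch.equal(layer.weight_mask, torch.ones_like(layer.weight)))  # True
print(hasattr(layer, "weight_orig"))                                  # True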
def init_prune_model(model):
    # Reset any existing pruning, then attach identity masks to the weights of
    # all top-level Linear and Conv2d layers.
    if prune.is_pruned(model):
        remove_pruning(model)
    for layer in model.children():
        if isinstance(layer, (nn.Linear, nn.Conv2d)):
            prune.identity(layer, name='weight')
def init_mask(self, model):
    """Constructs initial masks for all the layers in the network.

    Each mask is essentially a matrix of ones. No masks are constructed
    for the biases.

    Arguments
    ---------
    model: (nn.Module), the feed-forward network to prune
    """
    for n, m in model.named_children():
        if hasattr(m, 'weight'):
            prune.identity(m, name='weight')
            self.weight_mask[f"{n}.weight"] = m.weight_mask.detach().clone()
            prune.remove(m, name='weight')
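An illustrative sketch (a single nn.Linear stands in for a network layer) of the companion step to the identity -> clone -> remove pattern above: a mask saved this way can later be re-applied with prune.custom_from_mask.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
prune.identity(layer, name="weight")
saved_mask = layer.weight_mask.detach().clone()
prune.remove(layer, name="weight")       # back to a plain, unpruned parameter

saved_mask[0, 0] = 0                     # e.g. drop a single connection
prune.custom_from_mask(layer, name="weight", mask=saved_mask)
print(layer.weight[0, 0].item())         # 0.0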
def detect(self, **kwargs):
    super().detect(**kwargs)
    # Walk the modules in reverse to find the last Conv2d layer and attach an
    # identity mask to its weight before pruning.
    module_list = list(self.model.named_modules())
    for name, module in reversed(module_list):
        if isinstance(module, nn.Conv2d):
            self.prune_layer: str = name
            self.conv_module: nn.Module = prune.identity(module, 'weight')
            break
    length = self.conv_module.out_channels
    self.prune_num: int = int(length * self.prune_ratio)
    self.prune(**kwargs)
def prune_step_change(self, pstep, prune_mode):
    cuda_using = next(self.parameters()).is_cuda
    # reimport fixed u and v weights
    ui, vi = self.load_weights()
    # fix current weights
    uc, vc = (self.u_embeddings.weight.data.clone().cpu(),
              self.v_embeddings.weight.data.clone().cpu())
    # fix current masks
    if not list(self.u_embeddings.named_buffers()):
        # prune.identity(self.u_embeddings, name='weight')
        prune.identity(self.v_embeddings, name='weight')
    # umask = dict(self.u_embeddings.named_buffers())['weight_mask'].cpu()
    vmask = dict(self.v_embeddings.named_buffers())['weight_mask'].cpu()
    # u_temp = torch.nn.Embedding(self.vocab_size, self.emb_dimension)
    v_temp = torch.nn.Embedding(self.vocab_size, self.emb_dimension)
    if prune_mode == 'change':
        f = lambda x, y: x - y
    elif prune_mode == 'absolute change':
        f = lambda x, y: torch.abs(x) - torch.abs(y)
    else:
        f = lambda x, y: x
    # weights to be left must have higher function outputs
    # u_temp.weight.data.copy_(f(uc, ui))
    v_temp.weight.data.copy_(f(vc, vi))
    # prune.custom_from_mask(u_temp, name='weight', mask=umask)
    prune.custom_from_mask(v_temp, name='weight', mask=vmask)
    if cuda_using:
        # u_temp.cuda()
        v_temp.cuda()
    # prune.l1_unstructured(u_temp, name='weight', amount=pstep)
    prune.l1_unstructured(v_temp, name='weight', amount=pstep)
    # checked, cuda <-> cpu crash DNE
    # u_temp.weight.data.copy_(uc)
    v_temp.weight.data.copy_(vc)
    # self.u_embeddings = u_temp
    self.v_embeddings = v_temp
def prune(self, **kwargs):
    # Attach an identity mask to the last Conv2d layer, then prune its output
    # channels step by step, stopping once clean accuracy drops by more than
    # 20 points, and finally fine-tune the pruned model.
    for name, module in reversed(list(self.model.named_modules())):
        if isinstance(module, nn.Conv2d):
            self.last_conv: nn.Conv2d = prune.identity(module, 'weight')
            break
    length = self.last_conv.out_channels
    mask: torch.Tensor = self.last_conv.weight_mask
    self.prune_step(mask, prune_num=max(self.prune_num - 10, 0))
    self.attack.validate_fn()
    for i in range(min(10, length)):
        print('Iter: ', output_iter(i + 1, 10))
        self.prune_step(mask, prune_num=1)
        _, clean_acc = self.attack.validate_fn()
        if self.attack.clean_acc - clean_acc > 20:
            break
    file_path = os.path.join(self.folder_path, self.get_filename() + '.pth')
    self.model._train(validate_fn=self.attack.validate_fn,
                      file_path=file_path, **kwargs)
    self.attack.validate_fn()
def prune_by_masks(net, masks):
    # Attach identity masks to each layer, then overwrite the mask buffers
    # with the precomputed masks (weight and bias alternate in `masks`).
    pr.identity(net.l1, "weight")
    pr.identity(net.l1, "bias")
    net.l1.weight_mask = masks[0].to(device)
    net.l1.bias_mask = masks[1].to(device)
    if net.l >= 2:
        pr.identity(net.l2, "weight")
        pr.identity(net.l2, "bias")
        net.l2.weight_mask = masks[2].to(device)
        net.l2.bias_mask = masks[3].to(device)
    if net.l >= 3:
        pr.identity(net.l3, "weight")
        pr.identity(net.l3, "bias")
        net.l3.weight_mask = masks[4].to(device)
        net.l3.bias_mask = masks[5].to(device)
    pr.identity(net.out, "weight")
    pr.identity(net.out, "bias")
    net.out.weight_mask = masks[-2].to(device)
    net.out.bias_mask = masks[-1].to(device)
def init_for_pruning(smodel, to_be_pruned: list[str], last_layer_index=-1):
    # don't prune the output layer
    for layer in list(smodel.children())[:last_layer_index]:
        for pname in to_be_pruned:
            identity(layer, pname)
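A hedged usage sketch of init_for_pruning: the three-layer Sequential model below is hypothetical, and identity is assumed to have been imported from torch.nn.utils.prune as in the function above. With last_layer_index=-1, the final Linear layer is left unpruned.

import torch.nn as nn
import torch.nn.utils.prune as prune

smodel = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4), nn.Linear(4, 2))
init_for_pruning(smodel, to_be_pruned=["weight", "bias"], last_layer_index=-1)
print(prune.is_pruned(smodel))               # True
print(hasattr(smodel[0], "weight_mask"))     # True
print(hasattr(smodel[2], "weight_mask"))     # False (output layer skipped)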