Exemplo n.º 1
0
def create_identity_mask(model, tensor_name="weight"):
    """Attach an all-ones (identity) pruning mask to every module owning ``tensor_name``.

    ``prune.identity`` reparametrizes ``module.<tensor_name>`` as
    ``<tensor_name>_orig * <tensor_name>_mask`` with a mask of ones, so outputs
    are unchanged while the mask becomes available for later editing.

    Every module yielded by ``model.named_modules()`` (including ``model``
    itself) is inspected. Only modules that directly own a non-``None``
    parameter called ``tensor_name`` are touched. Modules already pruned on
    that tensor (they own ``<tensor_name>_orig``) are also touched, in which
    case ``prune.identity`` composes with the existing mask. Everything else
    is skipped silently: containers, parameter-free layers, and layers that
    register the parameter as ``None``.

    Args:
        model: the ``nn.Module`` to walk.
        tensor_name: name of the parameter to reparametrize.
    """
    orig_name = f"{tensor_name}_orig"
    for _name, module in model.named_modules():
        # recurse=False: only parameters registered on this very module.
        # named_parameters() omits parameters registered as None; calling
        # prune.identity on those fails only after it has already moved the
        # parameter to "<name>_orig", which would leave the module broken.
        own_params = dict(module.named_parameters(recurse=False))
        if tensor_name in own_params or orig_name in own_params:
            prune.identity(module, tensor_name)
Exemplo n.º 2
0
def init_prune_model(model):
    """Reset pruning on ``model`` and give its direct Linear/Conv2d children identity masks.

    When ``prune.is_pruned(model)`` reports existing pruning, ``remove_pruning``
    (defined elsewhere) is called first. Every immediate child that is an
    ``nn.Linear`` or an ``nn.Conv2d`` then has its ``weight`` reparametrized
    with an all-ones mask through ``prune.identity``. Grandchildren are not
    visited.
    """
    if prune.is_pruned(model):
        remove_pruning(model)
    maskable_types = (nn.Linear, nn.Conv2d)
    for child in model.children():
        for layer_type in maskable_types:
            if isinstance(child, layer_type):
                prune.identity(child, name='weight')
    def init_mask(self, model):
        """
        Record an initial (all-ones) weight mask for each direct child of the network.

        Only the immediate children of ``model`` (``model.named_children()``) are
        visited. Children without a ``weight``, or whose ``weight`` is registered
        as ``None``, are skipped. No masks are constructed for the biases. Each
        mask is stored in ``self.weight_mask`` under the key ``"<child name>.weight"``.

        The mask is obtained by temporarily applying ``prune.identity`` and
        removing it again, so the child ends up un-reparametrized.
        NOTE(review): if a child was already pruned, ``prune.identity`` composes
        with its existing mask, so the recorded mask is that mask rather than
        ones, and ``prune.remove`` makes that pruning permanent. Confirm this is
        intended.

        Arguments
        ---------
        model: (nn.Module), the feed forward network to prune
        """
        for n, m in model.named_children():
            # hasattr() alone is not enough: e.g. nn.LayerNorm(elementwise_affine=False)
            # registers ``weight`` as None, and prune.identity fails on it only
            # after detaching the parameter, which leaves the layer unusable.
            if getattr(m, 'weight', None) is None:
                continue
            prune.identity(m, name='weight')
            self.weight_mask[f"{n}.weight"] = m.weight_mask.detach().clone()
            prune.remove(m, name='weight')
Exemplo n.º 4
0
 def detect(self, **kwargs):
     """Run the base detection, then set up and run pruning of the last ``nn.Conv2d``.

     Sets ``self.prune_layer`` to the qualified name of the last ``nn.Conv2d``
     yielded by ``self.model.named_modules()``. Sets ``self.conv_module`` to that
     module after ``prune.identity`` has attached an all-ones ``weight_mask``.
     Sets ``self.prune_num`` to ``int(out_channels * self.prune_ratio)``.
     ``kwargs`` are forwarded both to the base ``detect`` and to ``self.prune``.

     Raises:
         ValueError: if ``self.model`` contains no ``nn.Conv2d``.
     """
     super().detect(**kwargs)
     # Scan from the end so the first match is the last registered Conv2d.
     for name, module in reversed(list(self.model.named_modules())):
         if isinstance(module, nn.Conv2d):
             self.prune_layer: str = name
             self.conv_module: nn.Module = prune.identity(module, 'weight')
             break
     else:
         # Without a Conv2d, self.conv_module below would be missing
         # (AttributeError) or possibly left over from an earlier call.
         raise ValueError('detect() requires a model containing an nn.Conv2d layer')
     length = self.conv_module.out_channels
     self.prune_num: int = int(length * self.prune_ratio)
     self.prune(**kwargs)
Exemplo n.º 5
0
    def prune_step_change(self, pstep, prune_mode):
        """Prune a further ``pstep`` of the ``v_embeddings`` weights, ranked by a score.

        Scores are ``f(current, reference)``, where ``reference`` is the v half of
        ``self.load_weights()``:

        * ``'change'``          -> ``current - reference``
        * ``'absolute change'`` -> ``|current| - |reference|``
        * anything else         -> ``current``

        ``prune.l1_unstructured`` then masks the ``pstep`` entries with the
        smallest absolute score among those not already masked. The existing
        ``weight_mask`` is kept as the starting mask. The rebuilt embedding holds
        the *current* weight values in its trainable ``weight_orig`` (only the
        mask changes) and replaces ``self.v_embeddings``. ``u_embeddings`` is not
        pruned here.

        Args:
            pstep: ``amount`` for ``prune.l1_unstructured`` (float fraction of the
                remaining entries, or int count).
            prune_mode: one of the mode strings above.
        """
        device = next(self.parameters()).device
        # Reference weights; only the v half is used while u pruning is disabled.
        _ui, vi = self.load_weights()
        # Snapshot of the current v weights (already masked if previously pruned).
        vc = self.v_embeddings.weight.data.clone().cpu()

        # Ensure v_embeddings (the module actually pruned) carries a weight_mask;
        # on the first call this is an all-ones identity mask.
        if 'weight_mask' not in dict(self.v_embeddings.named_buffers()):
            prune.identity(self.v_embeddings, name='weight')
        vmask = dict(self.v_embeddings.named_buffers())['weight_mask'].cpu()

        if prune_mode == 'change':
            f = lambda x, y: x - y
        elif prune_mode == 'absolute change':
            f = lambda x, y: torch.abs(x) - torch.abs(y)
        else:
            f = lambda x, y: x

        v_temp = torch.nn.Embedding(self.vocab_size, self.emb_dimension)
        # Weights to be kept must have higher scores; the parameter temporarily
        # holds the scores so l1_unstructured ranks by them.
        v_temp.weight.data.copy_(f(vc, vi))
        # Move to the model's device *before* pruning. After pruning, the derived
        # ``weight`` is a plain attribute that Module.to()/.cuda() does not move,
        # yet l1_unstructured reads it as its importance scores.
        v_temp.to(device)
        prune.custom_from_mask(v_temp, name='weight', mask=vmask.to(device))
        prune.l1_unstructured(v_temp, name='weight', amount=pstep)

        # Put the real weights into the trainable parameter. ``weight`` itself is
        # recomputed as weight_orig * weight_mask by the pruning forward pre-hook,
        # so writing only into it would leave the scores in weight_orig.
        v_temp.weight_orig.data.copy_(vc)
        # Refresh the derived tensor as well, so reads before the next forward
        # already see the masked current weights.
        v_temp.weight.data.copy_(v_temp.weight_orig.data * v_temp.weight_mask)

        self.v_embeddings = v_temp
Exemplo n.º 6
0
    def prune(self, **kwargs):
        """Prune output channels of the model's last ``nn.Conv2d``, then train.

        The method works in five steps:

        1. Attach an identity mask to the last ``nn.Conv2d`` yielded by
           ``self.model.named_modules()`` and store the module in ``self.last_conv``.
        2. Remove ``max(self.prune_num - 10, 0)`` channels in a single ``prune_step``.
        3. Remove one channel per step for up to ``min(10, out_channels)`` steps.
           Stop after the first step where clean accuracy has dropped by more
           than 20 relative to ``self.attack.clean_acc``.
        4. Call ``self.model._train`` with
           ``file_path = <self.folder_path>/<self.get_filename()>.pth``.
           ``kwargs`` are forwarded to this call.
        5. Validate once more.

        NOTE(review): when ``self.prune_num < 10`` the single-step loop can still
        prune up to 10 channels, i.e. more than ``prune_num``. Confirm this is intended.

        Raises:
            ValueError: if ``self.model`` contains no ``nn.Conv2d``.
        """
        for _name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, nn.Conv2d):
                self.last_conv: nn.Conv2d = prune.identity(module, 'weight')
                break
        else:
            # Otherwise self.last_conv would be missing or stale from an earlier call.
            raise ValueError('prune() requires a model containing an nn.Conv2d layer')
        length = self.last_conv.out_channels

        # The same mask buffer is handed to every prune_step call; presumably
        # prune_step edits it in place (TODO confirm).
        mask: torch.Tensor = self.last_conv.weight_mask
        self.prune_step(mask, prune_num=max(self.prune_num - 10, 0))
        self.attack.validate_fn()

        # Report progress against the number of steps that can actually run.
        n_steps = min(10, length)
        for i in range(n_steps):
            print('Iter: ', output_iter(i + 1, n_steps))
            self.prune_step(mask, prune_num=1)
            _, clean_acc = self.attack.validate_fn()
            if self.attack.clean_acc - clean_acc > 20:
                break
        file_path = os.path.join(self.folder_path,
                                 self.get_filename() + '.pth')
        self.model._train(validate_fn=self.attack.validate_fn,
                          file_path=file_path,
                          **kwargs)
        self.attack.validate_fn()
Exemplo n.º 7
0
def prune_by_masks(net, masks):
    pr.identity(net.l1, "weight")
    pr.identity(net.l1, "bias")
    net.l1.weight_mask = masks[0].to(device)
    net.l1.bias_mask = masks[1].to(device)
    if net.l >= 2:
        pr.identity(net.l2, "weight")
        pr.identity(net.l2, "bias")
        net.l2.weight_mask = masks[2].to(device)
        net.l2.bias_mask = masks[3].to(device)
        if net.l >= 3:
            pr.identity(net.l3, "weight")
            pr.identity(net.l3, "bias")
            net.l3.weight_mask = masks[4].to(device)
            net.l3.bias_mask = masks[5].to(device)
    pr.identity(net.out, "weight")
    pr.identity(net.out, "bias")
    net.out.weight_mask = masks[-2].to(device)
    net.out.bias_mask = masks[-1].to(device)
Exemplo n.º 8
0
def init_for_pruning(smodel, to_be_pruned: "list[str]", last_layer_index=-1):
    """Attach identity (all-ones) pruning masks to the named tensors of ``smodel``'s children.

    Args:
        smodel: module whose immediate children are processed; grandchildren
            are not visited.
        to_be_pruned: parameter names (e.g. ``["weight", "bias"]``) to
            reparametrize with ``identity`` on every selected child.
        last_layer_index: children are selected with
            ``list(smodel.children())[:last_layer_index]``. The default ``-1``
            leaves the final child (the output layer) unpruned; ``None``
            selects every child.

    Errors raised by ``identity`` propagate, e.g. when a selected child lacks
    one of the named parameters.
    """
    # Slicing off the tail keeps the output layer(s) unpruned.
    for layer in list(smodel.children())[:last_layer_index]:
        for pname in to_be_pruned:
            identity(layer, pname)