Example #1
    def process(self, data: Any) -> Any:
        """
        Implements a Processor for global pruning of model weights.
        """

        if self.model is None:
            return data

        if self.keep_model:
            original_state_dict = copy.deepcopy(self.model.cpu().state_dict())

        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=self.pruning_method,
            amount=self.amount,
        )

        for module, name in self.parameters_to_prune:
            prune.remove(module, name)

        output = self.model.cpu().state_dict()

        if self.keep_model:
            self.model.load_state_dict(original_state_dict)

        logging.info("[Client #%d] Global pruning applied.", self.client_id)

        return output
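The snippet above assumes a `self.parameters_to_prune` list prepared elsewhere. A minimal sketch of how such a list is typically collected; the helper name is hypothetical:

import torch.nn as nn

def collect_parameters_to_prune(model):
    """Gather (module, 'weight') pairs for every Conv2d and Linear layer."""
    return [(m, 'weight') for m in model.modules()
            if isinstance(m, (nn.Conv2d, nn.Linear))]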
Example #2
def weight_prune(prune_iter):
    conv_rate = (1 - ((1 - args.prune_per_conv)**prune_iter))
    fc_rate = (1 - ((1 - args.prune_per_linear)**prune_iter))
    out_rate = (1 - ((1 - args.prune_per_out)**prune_iter))
    # make prune mask
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=conv_rate)
        if isinstance(module, nn.Linear):
            if 'out' in name:
                prune.l1_unstructured(module, name='weight', amount=out_rate)
            else:
                prune.l1_unstructured(module, name='weight', amount=fc_rate)

    # mask copy
    cpd_mask = {}
    for name, mask in model.named_buffers():
        cpd_mask[name] = mask

    # make the pruning permanent
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.remove(module, name='weight')
        elif isinstance(module, nn.Linear):
            prune.remove(module, name='weight')
    # return copied mask
    return cpd_mask
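Note: the mask dictionary returned above is keyed by buffer names such as 'fc1.weight_mask'. A hedged sketch of re-applying those masks later with prune.custom_from_mask; the helper name is hypothetical:

import torch.nn.utils.prune as prune

def apply_saved_masks(model, cpd_mask):
    """Re-apply masks captured via named_buffers(); keys look like 'fc1.weight_mask'."""
    for name, module in model.named_modules():
        key = f"{name}.weight_mask"
        if key in cpd_mask:
            prune.custom_from_mask(module, name='weight', mask=cpd_mask[key])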
Example #3
def prune_model(model, amount, prune_mask, method=prune.L1Unstructured):
    model.to('cpu')
    model.mask_to_device('cpu')
    for name, module in model.named_modules():  # re-apply current mask to the model
        if isinstance(module, torch.nn.Linear):
            # if name != "fc4":
            prune.custom_from_mask(module, "weight", prune_mask[name])

    parameters_to_prune = (
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
        (model.fc4, 'weight'),
    )
    prune.global_unstructured(  # global prune the model
        parameters_to_prune,
        pruning_method=method,
        amount=amount,
    )

    for name, module in model.named_modules():  # make pruning permanent by removing the orig/mask values from the state dict
        if isinstance(module, torch.nn.Linear):
            # if name != "fc4":
            torch.logical_and(module.weight_mask, prune_mask[name],
                              out=prune_mask[name])  # Update progress mask
            prune.remove(module, 'weight')  # remove all those values in the global pruned model

    return model
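prune_model above expects prune_mask to map layer names ('fc1'...'fc4') to binary masks, and mask_to_device to be a custom method on the model. A minimal sketch of initializing that dict with all-ones masks, assuming the same four-Linear-layer model:

import torch

# bool masks keep the torch.logical_and(..., out=...) update in prune_model well-defined
prune_mask = {
    name: torch.ones_like(module.weight, dtype=torch.bool)
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
}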
Example #4
def prune_model(model, prune_type, prune_percent):
    ''' Sparsifies (L1) model weights with either global or layerwise prune_percent. Currently only prunes Conv2d layers.
    '''
    if prune_type == 'global':
        print('Globally pruning all Conv2d layers with {} sparsity'.format(
            prune_percent))
        parameters_to_prune = []
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                parameters_to_prune.append((module, 'weight'))

        prune.global_unstructured(tuple(parameters_to_prune),
                                  pruning_method=prune.L1Unstructured,
                                  amount=prune_percent)

    elif prune_type == 'layerwise':
        print('Layerwise pruning all Conv2d layers with {} sparsity'.format(
            prune_percent))
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                prune.l1_unstructured(module,
                                      name='weight',
                                      amount=prune_percent)

    else:
        print('Unknown pruning method: {}'.format(prune_type))

    # make pruning permanent
    # otherwise subsequent Coronal and Sagittal model calls fail due to weight name mismatch
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.remove(module, 'weight')
    return model
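A hedged usage sketch for the function above; the torchvision model is purely illustrative:

import torchvision

model = torchvision.models.resnet18()
model = prune_model(model, prune_type='global', prune_percent=0.5)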
Example #5
def lottery_fl_v3(server_model, models, dataset, arch, data_nums):
    new_model = create_model(
        dataset, arch
    )  #copy_model(server_model, dataset, arch, source_buff=dict(server_model.named_buffers()))
    num_models = len(models)
    num_data_total = sum(data_nums)
    with torch.no_grad():
        # Gather the weights from each client model (masks are baked in by prune.remove)
        weights = []
        for i in range(num_models):
            new_c_model = copy_model(models[i], dataset, arch)
            parameters_to_prune, _, _ = get_prune_params(new_c_model)
            for m, n in parameters_to_prune:
                prune.remove(m, n)
            weights.append(dict(new_c_model.named_parameters()))

        for name, param in new_model.named_parameters():
            param.data.copy_(torch.zeros_like(param.data))
        # Averaging weights
        for name, param in new_model.named_parameters():
            for i in range(num_models):
                # str.strip("_orig") strips characters, not a suffix, so trim it explicitly
                key = name[:-len("_orig")] if name.endswith("_orig") else name
                weighted_param = weights[i][key]  # torch.mul(weights[i][key], data_nums[i])
                param.data.copy_(param.data + weighted_param)
            avg = torch.div(param.data, num_models)
            param.data.copy_(avg)
    return new_model
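Example #5 unpacks three values from a get_prune_params helper that is not shown. A sketch consistent with that call is below; this is an assumption, and note that Example #21 further down uses a same-named helper with a different signature:

import torch.nn as nn

def get_prune_params(model):
    """Return ((module, 'weight') pairs, layer names, modules) for prunable layers."""
    params, names, modules = [], [], []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            params.append((module, 'weight'))
            names.append(name)
            modules.append(module)
    return params, names, modules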
Example #6
    def prune_updates(self, previous_weights):
        """ Prune aggregated updates. """

        updates = self.compute_weight_updates(previous_weights)
        updates_model = models_registry.get()
        updates_model.load_state_dict(updates, strict=True)

        parameters_to_prune = []
        for _, module in updates_model.named_modules():
            if isinstance(module, torch.nn.Conv2d) or isinstance(
                    module, torch.nn.Linear):
                parameters_to_prune.append((module, 'weight'))

        if hasattr(Config().clients, 'pruning_method') \
                and Config().clients.pruning_method == 'random':
            pruning_method = prune.RandomUnstructured
        else:
            pruning_method = prune.L1Unstructured

        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=pruning_method,
            amount=Config().clients.pruning_amount,
        )

        for module, name in parameters_to_prune:
            prune.remove(module, name)

        return updates_model.cpu().state_dict()
Example #7
    def _prune(self):
        prune.global_unstructured(self.parameters_to_prune(self.model),
                                  pruning_method=prune.L1Unstructured,
                                  amount=(1 - self.args.prune_ratio))
        # remove the pruning reparameterization
        for module, name in self.parameters_to_prune(self.model):
            prune.remove(module, name)
Example #8
def remove_prune_params(model):
    for key, value in model.named_modules():
        if isinstance(value, nn.Conv2d):
            prune.remove(value, name='weight')
        elif isinstance(value, nn.Linear):
            prune.remove(value, name='weight')
    return model
Example #9
    def prune(self, px, path):
        parameters_to_prune = []
        for m in self.netG.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                parameters_to_prune.append((m, 'weight'))
        parameters_to_prune = tuple(parameters_to_prune)
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=px,
        )

        new_dict = {}
        state_dict = self.netG.state_dict()
        for key in state_dict.keys():
            if 'mask' in key:
                new_dict[key] = state_dict[key]
        for name, m in self.netG.named_modules():
            if isinstance(m, nn.Conv2d):
                prune.remove(m, 'weight')

        self.netD.load_state_dict(torch.load(os.path.join(path, 'init_D.pth.tar')))
        self.netG.load_state_dict(torch.load(os.path.join(path, 'init_G.pth.tar')))

        for name, m in self.netG.named_modules():
            if isinstance(m, nn.Conv2d):
                print('pruning layer with custom mask:', name)
                prune.CustomFromMask.apply(m, 'weight', mask=new_dict[name+'.weight_mask'])
Example #10
def pruneWholeModel(module,blockList):
    global st
    global en
    global layer_number
    global candidateConvLayer
    candidateConvLayer = []
    global newList
    newList = []
    
    for bl in range(len(blockList)):

        if bl == 0:
            st = 0
        else:
            st = en

        en = en + blockList[bl]

        # skip the first four blocks
        if bl <= 3:
            continue

        newList = []
        for i in range(st, en):
            layer_number = i
            candidateConvLayer.append(fp.compute_distance_score(module[i]._parameters['weight'], threshold=1))
            # fp.sort_kernel_by_distance(candidateConvLayer[i])
        for i in range(st, en):
            newList.append(fp.get_k_element(channel_list=candidateConvLayer[i], k=prune_count[i]))
            kernal_unstructured(module=module[i], name='weight')
            prune.remove(module[i], 'weight')
Example #11
    def _remove_res_unit1(self):

        # reset the mask of the first conv layer to all ones, then drop the pruning hook
        first_conv = list(self.conv1)[0]
        first_conv.weight_mask.data = first_conv.weight_mask.data * 0 + 1

        prune.remove(first_conv, name="weight")
Example #12
    def make_pruning_permanent(self):
        """ Makes ``parameters_to_prune`` current pruning permanent. """
        for module, param_name in self._parameters_to_prune:
            try:
                pytorch_prune.remove(module, param_name)
            except ValueError:
                # pruning already made permanent
                pass
Example #13
def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    # print('Pruning model... ', end='')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
Example #14
def mask_merger(model):
    """remove mask but let weights stay pruned"""
    for name, module in model.named_modules():
        if not is_pruned(module):
            continue
        if isinstance(module, torch.nn.Conv2d) or isinstance(
                module, torch.nn.Linear):
            remove(module, name='weight')
Example #15
def neural_network_pruner(model):

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.6)
            prune.remove(module, "weight")
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.5)
            prune.remove(module, "weight")
Example #16
def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    print("Pruning model... ", end="")
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name="weight", amount=amount)  # prune
            prune.remove(m, "weight")  # make permanent
    print(f" {sparsity(model):.3f} global sparsity")
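Examples #13 and #16 call a sparsity helper that is not shown. A minimal sketch of what it presumably computes, the global fraction of zero-valued weights:

def sparsity(model):
    """Global fraction of zero-valued parameters in the model."""
    zeros, total = 0, 0
    for p in model.parameters():
        zeros += (p == 0).sum().item()
        total += p.numel()
    return zeros / total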
Example #17
    def finishpruning(self):
        """Finalizes pruning by removing the parallel set of original unpruned
        weights.

        NB: changes parameter names relative to the original model. Must be
        kept in mind while loading weights.
        """
        for x in self.prun_param:
            prune.remove(x[0], "weight")
Example #18
def prune_layer(layer, prune_rate, reset=True):
    """ Prune given 'layer' with 'prune_rate' and reset surviving weights to their initial values, if 'reset' is True.
    Calls with 'prune_rate'=0.0 do not remove weights. """
    if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
        pruned_mask = prune_mask(layer, prune_rate)
        if reset:
            prune.remove(layer, name='weight')  # temporarily remove pruning
            layer.weight = nn.Parameter(
                layer.weight_init.clone())  # set weights to initial weights
        prune.custom_from_mask(layer, name='weight',
                               mask=pruned_mask)  # apply pruned mask
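prune_mask here is an external helper, and the weight_init attribute is likewise assumed to have been stashed at initialization. A hedged sketch of one way such a mask could be computed with PyTorch's built-in L1 scoring; the body is an assumption, not the original helper:

import torch
import torch.nn.utils.prune as prune

def prune_mask(layer, prune_rate):
    """Compute an L1-magnitude mask that preserves zeros from the current mask."""
    default = getattr(layer, 'weight_mask', torch.ones_like(layer.weight))
    pruner = prune.L1Unstructured(amount=prune_rate)
    return pruner.compute_mask(layer.weight.detach(), default_mask=default)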
Example #19
def unPruneNetwork(net):
    """
    Remove the pruning hooks and masks of weights from a pruned network
    :param net: The network to be pruned
    :return: None
    """
    for name, module in net.named_modules():
        # snapshot the hooks: prune.remove mutates _forward_pre_hooks during iteration
        for _, hook in list(module._forward_pre_hooks.items()):
            if isinstance(hook, prune.BasePruningMethod):
                prune.remove(module, "weight")
                break
Example #20
    def prune_and_save_model(model, amount):

        for _, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, name="weight", amount=amount)
                prune.remove(module, "weight")

        mlflow.pytorch.save_state_dict(model.state_dict(), ".")
        model = torch.load("state_dict.pth")
        os.remove("state_dict.pth")
        return model
Example #21
    def upload(self, *args, **kwargs) -> Dict[nn.Module, float]:
        """
        Upload self.model
        """
        upload_model = copy_model(model=self.model, device=self.args.device)
        params_pruned = get_prune_params(upload_model, name='weight')
        for param, name in params_pruned:
            prune.remove(param, name)
        return {
            'model': upload_model,
            'acc': self.accuracies[-1]
        }
Example #22
    def setup_graph(self):
        # initialize the masks
        if self._mask_init_method == 'random':
            for i, mask in enumerate(self.masks):
                size = mask.size()
                self.masks[i] = torch.tensor(get_mask_random(size, self._default_sparsity), dtype=mask.dtype).to(device)

        # initialize masked weight
        for i, layer in enumerate(self.layers):
            if prune.is_pruned(layer):
                prune.remove(layer, 'weight')
            prune.custom_from_mask(layer, 'weight', self.masks[i])
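get_mask_random is an external helper here; judging by the torch.tensor(...) wrapper above, it likely returns a plain array. A minimal sketch of a random binary mask at a target sparsity; this is an assumption:

import torch

def get_mask_random(size, sparsity):
    """Random binary mask where roughly `sparsity` of the entries are zero."""
    return (torch.rand(size) >= sparsity).float()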
Example #23
def prune_net(net, threshold):
    """
    Prunes the network according to the unstructured L1 method
    and the parameters specified in __modules_to_prune__
    """
    # TODO: structured pruning should be added as a possibility
    parameters = modules_to_prune(net)
    prune.global_unstructured(parameters,
                              pruning_method=prune.L1Unstructured,
                              amount=threshold)
    for module, name in parameters:
        prune.remove(module, name)
    return net
Example #24
def do_prune(cnn_module, threshold):
    # do pruning
    parameters_to_prune = [(cnn_module.conv1, 'weight'),
                           (cnn_module.conv2, 'weight'),
                           (cnn_module.conv3, 'weight'),
                           (cnn_module.fc1, 'weight'),
                           (cnn_module.fc2, 'weight')]
    prune.global_unstructured(parameters_to_prune,
                              pruning_method=ThresholdPruningMethod,
                              threshold=threshold)

    for child in cnn_module.children():
        prune.remove(child, 'weight')
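ThresholdPruningMethod is a custom class not shown in this example. A minimal sketch of such a magnitude-threshold method built on prune.BasePruningMethod; this is an assumption about its behavior, not the original class:

import torch.nn.utils.prune as prune

class ThresholdPruningMethod(prune.BasePruningMethod):
    """Zero out every weight whose absolute value falls below `threshold`."""
    PRUNING_TYPE = 'unstructured'

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # keep previously pruned entries pruned; drop small-magnitude weights
        return default_mask * (t.abs() >= self.threshold)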
Example #25
    def stress_test():
        while True:
            C = random.randint(1, 100) % 65 + 3
            K = random.randint(1, 100) % 65 + 16
            R = random.randint(1, 100) % 8 + 1
            cfg.xbar_row_size = random.randint(1, 100) % 64 + 1
            xbar_strategy  = random.randint(0,1)
            xbar_strategy_name  = {0:"dynamic", 1:"static"}
            #C, K, R = 3, 3, 1
            #cfg.xbar_row_size = 1

            class LeNet(nn.Module):
                def __init__(self):
                    super(LeNet, self).__init__()
                    #self.conv1 = nn.Conv2d(C, K, R)
                    self.fc1 = nn.Linear(C, K)

                def forward(self, x):
                    #x = self.conv1(x)
                    x = self.fc1(x)
                    return x

            model = LeNet().to(device=device)
    
            def sparsity(weight):
                return float(torch.sum(weight==0.0))/float(weight.nelement())

            #module = model.conv1
            #module.weight = torch.nn.Parameter(torch.arange(0,C*K*R*R).view(K, C, R, R).float())
            module = model.fc1
            #module.weight = torch.nn.Parameter(torch.arange(0,C*K).view(K, C).float())
    
            try:
                # Prune a model initially - unstructured pruning
                if (xbar_strategy_name[xbar_strategy] == 'dynamic'):
                    prune.l1_unstructured(module, name="weight", amount=0.5)
                    prune.remove(module, 'weight')
                s_mat = sparsity(module.weight)

                # Fine-tune with xbar-aware pruning
                l1_xbar_unstructured(module, name="weight", threshold=0.5, xbar_strategy=xbar_strategy_name[xbar_strategy])
                s_xbar = sparsity(module.weight)
                
                #print("{:.2f}" .format(s_xbar))
                prune.remove(module, 'weight')
                print('Passed:\t Mat {0:0.2f}\t Xbar {1:0.2f} \t [C {2}, K {3}, R {4}, xbar_row_size {5}] Xbar strategy {6}'
                      .format(s_mat, s_xbar, C, K, R, cfg.xbar_row_size, xbar_strategy_name[xbar_strategy]))
                assert (s_xbar >= s_mat)
            except Exception as e:
                print ("Failed configuration [C, K, R, xbar_row_size, strategy] ", [C, K, R, cfg.xbar_row_size, xbar_strategy_name[xbar_strategy]])
                raise e
Example #26
def remove_prune_model_custom(model):

    parameters_to_prune = []
    for ii in range(12):
        parameters_to_prune.append(model.bert.encoder.layer[ii].attention.self.query)
        parameters_to_prune.append(model.bert.encoder.layer[ii].attention.self.key)
        parameters_to_prune.append(model.bert.encoder.layer[ii].attention.self.value)
        parameters_to_prune.append(model.bert.encoder.layer[ii].attention.output.dense)
        parameters_to_prune.append(model.bert.encoder.layer[ii].intermediate.dense)
        parameters_to_prune.append(model.bert.encoder.layer[ii].output.dense)
    # parameters_to_prune.append(model.bert.pooler.dense)

    for module in parameters_to_prune:
        prune.remove(module, 'weight')
Example #27
    def pruning(self):
        parameters_to_prune = self.get_prune_parameters()
        prune_method = self.get_prune_method()
        prune.global_unstructured(
            parameters=parameters_to_prune,
            pruning_method=prune_method,
            amount=self.percentage,
        )

        # clean up re-parameterization
        for module, name in parameters_to_prune:
            prune.remove(module, name)

        return self.model
Example #28
def prune_transformer_block(transformer_block, args):
    pruning_amount = float(args.pruning_amount)
    prune.ln_structured(transformer_block.fc1,
                        name='weight',
                        amount=pruning_amount,
                        n=0,
                        dim=0)
    prune.remove(transformer_block.fc1, 'weight')
    prune.ln_structured(transformer_block.fc2,
                        name='weight',
                        amount=pruning_amount,
                        n=0,
                        dim=0)
    prune.remove(transformer_block.fc2, 'weight')
    for sub_module in transformer_block.fc_delta:
        if isinstance(sub_module, torch.nn.Linear):
            prune.ln_structured(sub_module,
                                name='weight',
                                amount=pruning_amount,
                                n=0,
                                dim=0)
            prune.remove(sub_module, 'weight')
    for sub_module in transformer_block.fc_gamma:
        if isinstance(sub_module, torch.nn.Linear):
            prune.ln_structured(sub_module,
                                name='weight',
                                amount=pruning_amount,
                                n=0,
                                dim=0)
            prune.remove(sub_module, 'weight')
    return transformer_block
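A note on the parameters used above: dim=0 prunes entire output rows of each Linear weight, and n=0 ranks rows by their L0 norm (the count of nonzero entries), as in torch.norm. A small standalone check of the same pattern; the layer here is illustrative:

import torch.nn as nn
import torch.nn.utils.prune as prune

fc = nn.Linear(8, 4)
prune.ln_structured(fc, name='weight', amount=0.5, n=0, dim=0)  # zero half of the rows
prune.remove(fc, 'weight')
print((fc.weight.abs().sum(dim=1) == 0).sum().item(), "rows pruned")  # -> 2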
Example #29
    def remove(self):
        prune.remove(self.conv1, name='weight')
        prune.remove(self.conv2, name='weight')
        prune.remove(self.conv3, name='weight')
        if self.se is not None:
            for c in self.se.se:
                if isinstance(c, nn.Conv2d):
                    prune.remove(c, name='weight')
Example #30
    def init_mask(self, model):
        """
        Constructs initial masks for all the layers in the network. Each mask is essentially a matrix of ones.

        No masks are constructed for the biases

        Arguments
        -------
        model: (nn.Module), the feed forward network to prune
        """
        for n, m in model.named_children():
            if hasattr(m, 'weight'):
                prune.identity(m, name='weight')
                self.weight_mask[f"{n}.weight"] = m.weight_mask.detach().clone()
                prune.remove(m, name='weight')