Example No. 1
 def prune_actor(self, ratios, dims):
     if len(ratios) != 2:
         raise ValueError("length of ratios not matching actor number of layers")
     if len(dims) != 2:
         raise ValueError("length of dims not matching actor number of layers")
     prune.ln_structured(self.actor[0], "weight", amount=ratios[0], n=1, dim=dims[0])
     prune.ln_structured(self.actor[2], "weight", amount=ratios[1], n=1, dim=dims[1])
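A minimal sketch (not part of the example above; the toy two-layer actor is an assumption) showing how to verify that ln_structured zeroes whole slices along the requested dim:

import torch.nn as nn
import torch.nn.utils.prune as prune

actor = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
prune.ln_structured(actor[0], "weight", amount=0.5, n=1, dim=0)

# dim=0 prunes rows of the weight matrix, i.e. whole output units
zero_rows = (actor[0].weight.abs().sum(dim=1) == 0).sum().item()
print(f"{zero_rows}/{actor[0].weight.shape[0]} output units pruned")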
Example No. 2
def prune_step(model, a1=0.01, a2=0.01, conv_group=True):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d) and conv_group:
            prune.ln_structured(module, name='weight', amount=a1, n=2, dim=0)
            prune.ln_structured(module, name='weight', amount=a1, n=2, dim=1)

            #prune.remove(module, name='weight')
        elif isinstance(module, torch.nn.Linear) or isinstance(
                module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=a2)
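The example applies ln_structured twice to the same parameter (dim=0, then dim=1). A standalone sketch (an illustration, not from the example) of what happens in that case: PyTorch stacks the masks in a PruningContainer, so the two structured masks combine multiplicatively.

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(8, 16, 3)
prune.ln_structured(conv, name='weight', amount=0.25, n=2, dim=0)
prune.ln_structured(conv, name='weight', amount=0.25, n=2, dim=1)

# the forward pre-hook is now a PruningContainer holding both methods
for hook in conv._forward_pre_hooks.values():
    print(type(hook))  # <class 'torch.nn.utils.prune.PruningContainer'>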
Example No. 3
 def prune_module(module, method, amount):
     if method == "ln":
         prune.ln_structured(module,
                             name="weight",
                             amount=amount,
                             n=2,
                             dim=0)
     elif method == "l1":
         prune.l1_unstructured(module, name="weight", amount=amount)
     else:
         raise ValueError(f"unknown pruning method: {method}")
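A possible call site for the helper above (the layer and amounts are hypothetical):

import torch.nn as nn

layer = nn.Conv2d(16, 32, kernel_size=3)
prune_module(layer, method="ln", amount=0.3)  # drop 30% of output channels
prune_module(layer, method="l1", amount=0.2)  # then 20% of remaining weights
print(layer.weight_mask.unique())             # tensor([0., 1.])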
Example No. 4
def prune_transformer_block(transformer_block, args):
    pruning_amount = float(args.pruning_amount)
    prune.ln_structured(transformer_block.fc1,
                        name='weight',
                        amount=pruning_amount,
                        n=0,
                        dim=0)
    prune.remove(transformer_block.fc1, 'weight')
    prune.ln_structured(transformer_block.fc2,
                        name='weight',
                        amount=pruning_amount,
                        n=0,
                        dim=0)
    prune.remove(transformer_block.fc2, 'weight')
    for sub_module in transformer_block.fc_delta:
        if isinstance(sub_module, torch.nn.Linear):
            prune.ln_structured(sub_module,
                                name='weight',
                                amount=pruning_amount,
                                n=0,
                                dim=0)
            prune.remove(sub_module, 'weight')
    for sub_module in transformer_block.fc_gamma:
        if isinstance(sub_module, torch.nn.Linear):
            prune.ln_structured(sub_module,
                                name='weight',
                                amount=pruning_amount,
                                n=0,
                                dim=0)
            prune.remove(sub_module, 'weight')
    return transformer_block
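Each prune.remove call above makes the pruning permanent: it folds weight_mask into the tensor and deletes the reparametrization. A minimal sketch (a generic nn.Linear standing in for fc1/fc2, as an assumption):

import torch.nn as nn
import torch.nn.utils.prune as prune

fc = nn.Linear(64, 64)
prune.ln_structured(fc, name='weight', amount=0.5, n=0, dim=0)
prune.remove(fc, 'weight')

assert not hasattr(fc, 'weight_orig')      # reparametrization is gone
print((fc.weight.sum(dim=1) == 0).sum())   # half the rows stay zero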
Example No. 5
 def prune(self):
     prune.ln_structured(self.conv1, name='weight', amount=self.sparsity, n=1, dim=0)
     prune.ln_structured(self.conv2, name='weight', amount=self.sparsity, n=1, dim=0)
     prune.ln_structured(self.conv3, name='weight', amount=self.sparsity, n=1, dim=0)
     if self.se is not None:
         for c in self.se.se:
             if isinstance(c, nn.Conv2d):
                 prune.ln_structured(c, name='weight', amount=self.sparsity, n=1, dim=0)
Example No. 6
    def prune_all(self):
        for layer_idx in range(self.nb_layers):
            conv = getattr(self.model, f"conv{layer_idx+1}")
            bn = getattr(self.model, f"bn{layer_idx+1}")
            prune.ln_structured(module=conv,
                                name='weight',
                                amount=self.amount,
                                n=self.norme,
                                dim=self.dim)
            prune.l1_unstructured(module=bn, name='weight', amount=self.amount)

        prune.ln_structured(module=self.model.fc1,
                            name='weight',
                            amount=self.amount,
                            n=self.norme,
                            dim=self.dim)
Example No. 7
 def prune_all(self):
     prune.ln_structured(module=self.model.conv1,
                         name='weight',
                         amount=self.amount,
                         n=self.norme,
                         dim=self.dim)
     prune.l1_unstructured(module=self.model.bn1,
                           name='weight',
                           amount=self.amount)
     self.prune_block(self.model.layer1)
     self.prune_block(self.model.layer2)
     self.prune_block(self.model.layer3)
     prune.ln_structured(module=self.model.linear,
                         name='weight',
                         amount=self.amount,
                         n=self.norme,
                         dim=self.dim)
Example No. 8
 def prune_model(model):
     # total_prune_amount, max_epochs, prune_type and conv_layers are
     # module-level globals
     remove_amount = total_prune_amount / (max_epochs * 10)
     print(f'pruning model by {remove_amount}')
     if prune_type == 'global_unstructured':
         parameters_to_prune = [(layer, 'weight') for layer in conv_layers]
         prune.global_unstructured(
             parameters_to_prune,
             pruning_method=prune.L1Unstructured,
             amount=remove_amount,
         )
     else:
         for layer in conv_layers:
             prune.ln_structured(layer,
                                 name='weight',
                                 amount=remove_amount,
                                 n=1,
                                 dim=0)
Example No. 9
def channel_pruning(old_model,
                    pruning_option="l1",
                    name="weight",
                    prune_ratio=0.5):
    alive_weight_index = []

    # Get alive channel indices via L1 structured pruning
    if pruning_option == "l1":
        for _, old_module in old_model.named_modules():
            if isinstance(old_module, torch.nn.Conv2d):
                prune.ln_structured(old_module,
                                    name=name,
                                    amount=prune_ratio,
                                    n=1,
                                    dim=0)
                alive_index = alive_channel_index(old_module)
                alive_weight_index.append(alive_index)

    # Get alive channel indices via L2 structured pruning
    elif pruning_option == "l2":
        for _, old_module in old_model.named_modules():
            if isinstance(old_module, torch.nn.Conv2d):
                prune.ln_structured(old_module,
                                    name=name,
                                    amount=prune_ratio,
                                    n=2,
                                    dim=0)
                alive_index = alive_channel_index(old_module)
                alive_weight_index.append(alive_index)

    # Get alive channel indices via random structured pruning
    elif pruning_option == "random":
        for _, old_module in old_model.named_modules():
            if isinstance(old_module, nn.Conv2d):
                num_out_channel = old_module.weight.data.shape[0]
                num_pruned_channel = int(prune_ratio * num_out_channel)
                num_alive_channel = num_out_channel - num_pruned_channel
                alive_index = random.sample(range(num_out_channel),
                                            num_alive_channel)
                alive_weight_index.append(alive_index)

    return alive_weight_index
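alive_channel_index is not shown in this snippet; a plausible implementation (an assumption, not the original helper) reads the channel mask left by structured pruning:

import torch

def alive_channel_index(module):
    # weight_mask has the same shape as weight; a pruned output channel
    # corresponds to an all-zero mask slice along dim 0
    mask = module.weight_mask
    alive = mask.flatten(1).sum(dim=1) != 0
    return torch.nonzero(alive, as_tuple=False).flatten().tolist()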
Example No. 10
def get_pruned_model(model, normed=False, amount=0.5):
    model_pruned = deepcopy(model)
    for i in range(len(model)):
        if isinstance(model_pruned[i], nn.Conv2d):
            prune.ln_structured(model_pruned[i],
                                name='weight',
                                amount=amount,
                                n=2,
                                dim=0)
            if normed:
                # rescale weight_orig rather than weight: while the pruning
                # hook is active, weight is recomputed as weight_orig *
                # weight_mask on every forward, discarding direct assignments
                with torch.no_grad():
                    model_pruned[i].weight_orig.mul_(1 / (1 - amount))
        elif isinstance(model_pruned[i], nn.Linear) and i != len(model) - 2:
            prune.ln_structured(model_pruned[i],
                                name='weight',
                                amount=amount,
                                n=2,
                                dim=0)
            if normed:
                with torch.no_grad():
                    model_pruned[i].weight_orig.mul_(1 / (1 - amount))
    return model_pruned
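Hypothetical usage (the model is an assumption, relying on the snippet's own imports): the function only works on purely sequential models, since layers are accessed by integer index.

import torch.nn as nn

mlp = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
pruned = get_pruned_model(mlp, normed=True, amount=0.5)
print(pruned[0].weight_mask.mean())  # fraction of surviving weights, ~0.5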
Example No. 11
    def prune_by_percentile(self, amount=5.0):
        '''
        Prune the weights of selected modules using a percentile threshold
        computed over all surviving (non-zero) parameters.
        '''
        alive_parameters = []
        for name, p in self.named_parameters():
            if 'bias' in name or 'mask' in name:
                continue
            tensor = p.data.cpu().numpy()
            alive = tensor[np.nonzero(tensor)]
            alive_parameters.append(alive)

        all_alives = np.concatenate(alive_parameters)
        percentile_value = np.percentile(abs(all_alives), amount)
        logging.info(f'Pruning with threshold : {percentile_value}')

        for name, module in self.named_modules():
            if 'denselayer' in name:
                # ln_structured expects a fraction (or an absolute count) of
                # units, not a magnitude threshold, so pass the percentile
                # argument as a ratio; percentile_value is only logged above
                prune.ln_structured(module, name='weight', amount=amount / 100.0, n=2, dim=0)
Example No. 12
def pruning_cp_fg(net, a_list):
    if not isinstance(net, nn.Module):
        print('Invalid input. Must be nn.Module')
        return
    newnet = copy.deepcopy(net)
    i = 0
    for name, module in newnet.named_modules():
        if isinstance(module, nn.Conv2d):
            # print("Sparsity ratio",a_list[i])
            prune.ln_structured(module,
                                name='weight',
                                amount=float(1 - a_list[i]),
                                n=2,
                                dim=0)
            i += 1
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module,
                                  name='weight',
                                  amount=float(1 - a_list[i]))
            i += 1
    return newnet
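Hypothetical usage (the network and ratios are assumptions): a_list holds one keep ratio per prunable layer, so the prune rate of layer i is 1 - a_list[i].

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(),
                    nn.Conv2d(16, 32, 3), nn.Flatten(),
                    nn.Linear(32 * 28 * 28, 10))
pruned = pruning_cp_fg(net, [0.8, 0.6, 0.5])  # keep 80%, 60%, 50%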
Example No. 13
 def prune_block(self, sub_layer):
     for block_num, block in enumerate(sub_layer):
         prune.ln_structured(module=block.conv1,
                             name='weight',
                             amount=self.amount,
                             n=self.norme,
                             dim=self.dim)
         prune.l1_unstructured(module=block.bn1,
                               name='weight',
                               amount=self.amount)
         prune.ln_structured(module=block.conv2,
                             name='weight',
                             amount=self.amount,
                             n=self.norme,
                             dim=self.dim)
         prune.l1_unstructured(module=block.bn2,
                               name='weight',
                               amount=self.amount)
         for short_layer in block.shortcut:
             if isinstance(short_layer, torch.nn.modules.conv.Conv1d):
                 prune.ln_structured(module=short_layer,
                                     name='weight',
                                     amount=self.amount,
                                     n=self.norme,
                                     dim=self.dim)
             elif isinstance(short_layer,
                             torch.nn.modules.batchnorm.BatchNorm1d):
                 prune.l1_unstructured(module=short_layer,
                                       name='weight',
                                       amount=self.amount)
Example No. 14
def channel_pruning(net, a_list):
    '''
    :param net: DNN
    :param a_list: per-layer keep ratios (layer i is pruned at rate 1 - a_list[i])
    :return: newnet (nn.Module): a copy of net whose conv weights carry pruning masks
    '''

    if not isinstance(net, nn.Module):
        print('Invalid input. Must be nn.Module')
        return
    newnet = copy.deepcopy(net)
    i = 0
    for name, module in newnet.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module,
                                name='weight',
                                amount=float(1 - a_list[i]),
                                n=2,
                                dim=0)
            i += 1

    return newnet
Example No. 15
def main():
    config = '/media/shalev/98a3e66d-f664-402a-9639-15ec6b8a7150/work_dirs/try2/faster_rcnn_r50_caffe_c4_1x_coco_shalev.py'
    checkpoint = '/media/shalev/98a3e66d-f664-402a-9639-15ec6b8a7150/work_dirs/try2/latest.pth'
    src_img_path = '/home/shalev/downloads/1pic_coco/000000000285.jpg'
    dst_img_path = '/home/shalev/downloads/1pic_coco/000000000285_res.jpg'
    img = cv2.imread(src_img_path)
    model = mmdet.apis.init_detector(config,
                                     checkpoint=checkpoint,
                                     device='cuda:0')
    for i in range(10):
        if PRUNE:
            # backbone = model.backbone
            modules = [
                model.backbone.children(),
                model.roi_head.children(),
                model.rpn_head.children()
            ]
            for main_module in modules:
                for module in main_module:
                    if isinstance(module, torch.nn.Conv2d) or isinstance(
                            module, torch.nn.Linear):
                        print("before: ", module.weight.sum())
                        prune.ln_structured(module,
                                            name='weight',
                                            amount=0.05,
                                            dim=0,
                                            n=float('-inf'))
                        print("after: ", module.weight.sum())
                    else:
                        for sub in module.children():
                            if isinstance(sub, torch.nn.Conv2d) or isinstance(
                                    sub, torch.nn.Linear):
                                print("before: ", sub.weight.sum())
                                prune.ln_structured(sub,
                                                    name='weight',
                                                    amount=0.15,
                                                    dim=0,
                                                    n=float('-inf'))
                                print("after: ", sub.weight.sum())
                            else:
                                for sub_sub in sub.children():
                                    if isinstance(
                                            sub_sub,
                                            torch.nn.Conv2d) or isinstance(
                                                sub_sub, torch.nn.Linear):
                                        print("before: ", sub_sub.weight.sum())
                                        prune.ln_structured(sub_sub,
                                                            name='weight',
                                                            amount=0.15,
                                                            dim=0,
                                                            n=float('-inf'))
                                        print("after: ", sub_sub.weight.sum())

        start = time.time()
        res = mmdet.apis.inference_detector(model, img)
        print("Inference time: ", (time.time() - start))
        if hasattr(model, 'module'):
            model = model.module
        img_res = model.show_result(img, res, score_thr=0.305, show=True)
Example No. 16
 def prune_critic(self, ratios, dims):
     if len(ratios) != 3:
         raise ValueError("length of ratios not matching critic number of layers")
     if len(dims) != 3:
         raise ValueError("length of ratios not matching critic number of layers")
     prune.ln_structured(self.critic[0], "weight", amount=ratios[0], n=1, dim=dims[0])
     prune.ln_structured(self.critic[2], 'weight', amount=ratios[1], n=1, dim=dims[1])
     prune.ln_structured(self.critic_linear, "weight", amount=ratios[2], n=2, dim=dims[2])
Example No. 17
 def _prune_res_unit5(self, ratio=0.1):
     prune.ln_structured(list(self.conv5_x)[0].conv1,
                         name="weight",
                         amount=ratio,
                         n=1,
                         dim=0)
     prune.ln_structured(list(self.conv5_x)[0].conv2,
                         name="weight",
                         amount=ratio,
                         n=1,
                         dim=0)
     prune.ln_structured(list(self.conv5_x)[1].conv1,
                         name="weight",
                         amount=ratio,
                         n=1,
                         dim=0)
     prune.ln_structured(list(self.conv5_x)[1].conv2,
                         name="weight",
                         amount=ratio,
                         n=1,
                         dim=0)
Example No. 18
                       weight_bit_num=weight_bit_width)
        # end = time.time()
        # print(f'It takes {end-start:.6f} seconds.')

        net_gpu.load_state_dict(
            torch.load('weight_bit_' + str(weight_bit_width) + '_best.pth'))
        net_gpu.to(device)

        resume_acc = evaluate(net_gpu, xtest_gpu, ytest)
        print('quantization best accuracy: {:.5f}'.format(resume_acc))

        #prune
        conv_module = net_gpu.conv2
        prune.ln_structured(conv_module,
                            name='weight',
                            amount=amount_num,
                            n=2,
                            dim=0)

        # print(list(model.conv2.named_parameters()))
        #
        prune_acc = evaluate(net_gpu, xtest_gpu, ytest)

        print('prune accuracy: {:.5f}'.format(prune_acc))

        fine_tune_epoch = 20
        train_and_eval(xtrain_gpu,
                       ytrain_gpu,
                       net_gpu,
                       xtest_gpu,
                       ytest,
Example No. 19
    def pruning(model0, percentage, method):
        # copy a model0 for pruning
        # model0=copy.deepcopy(model.to(device))
        if method == "unstructured":
            for name, module in model0.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    prune.l1_unstructured(module,
                                          name='weight',
                                          amount=percentage)
                # prune lstm layers
                elif isinstance(module, torch.nn.LSTM):
                    prune.l1_unstructured(module,
                                          name='weight_hh_l0',
                                          amount=percentage)
                    prune.l1_unstructured(module,
                                          name='weight_ih_l0',
                                          amount=percentage)
                    prune.l1_unstructured(module,
                                          name='weight_hh_l1',
                                          amount=percentage)
                    prune.l1_unstructured(module,
                                          name='weight_ih_l1',
                                          amount=percentage)
                # prune  linear layers
                elif isinstance(module, torch.nn.Linear):
                    prune.l1_unstructured(module,
                                          name='weight',
                                          amount=percentage)
        elif method == "structured":
            for name, module in model0.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    prune.ln_structured(module,
                                        name='weight',
                                        amount=percentage,
                                        n=1,
                                        dim=0)
                # prune lstm layers
                elif isinstance(module, torch.nn.LSTM):
                    prune.ln_structured(module,
                                        name='weight_hh_l0',
                                        amount=percentage,
                                        n=1,
                                        dim=0)
                    prune.ln_structured(module,
                                        name='weight_ih_l0',
                                        amount=percentage,
                                        n=1,
                                        dim=0)
                    prune.ln_structured(module,
                                        name='weight_hh_l1',
                                        amount=percentage,
                                        n=1,
                                        dim=0)
                    prune.ln_structured(module,
                                        name='weight_ih_l1',
                                        amount=percentage,
                                        n=1,
                                        dim=0)
                # prune  linear layers
                elif isinstance(module, torch.nn.Linear):
                    prune.ln_structured(module,
                                        name='weight',
                                        amount=percentage,
                                        n=1,
                                        dim=0)
        for name, module in model0.named_modules():
            if isinstance(module, torch.nn.Embedding):
                prune.remove(module, 'weight')
            # prune  lstm layers
            elif isinstance(module, torch.nn.LSTM):
                prune.remove(module, 'weight_hh_l0')
                prune.remove(module, 'weight_ih_l0')
                prune.remove(module, 'weight_hh_l1')
                prune.remove(module, 'weight_ih_l1')
            # prune  linear layers
            elif isinstance(module, torch.nn.Linear):
                prune.remove(module, 'weight')

        test_data = TensorDataset(torch.Tensor(X), torch.Tensor(Y).long())
        test_loader = DataLoader(dataset=test_data,
                                 batch_size=batch_size,
                                 shuffle=True)
        # model_prune=copy.deepcopy(model0.to(device))
        model_prune = model0
        optimizer = torch.optim.Adam(model_prune.parameters(), lr)
        model_ls = []
        accuracy = []
        for epoch in range(num_epochs_retrain):
            for i, (x, y) in enumerate(train_loader):
                x = x.reshape(-1, time_steps, input_size).to(device)
                y = y.to(device)

                # forward pass
                outputs = model_prune(x)
                loss = criterion(outputs, y)

                # backward and optimize
                optimizer.zero_grad()
                loss.backward()

                for name, param in model_prune.named_parameters():
                    if "weight" in name:
                        param_data = param.data.cpu().numpy()
                        param_grad = param.grad.data.cpu().numpy()
                        # freeze gradients of pruned (near-zero) weights;
                        # compare magnitudes so negative weights keep theirs
                        param_grad = np.where(np.abs(param_data) < 0.00001, 0,
                                              param_grad)
                        param.grad.data = torch.from_numpy(param_grad).to(
                            device)

                optimizer.step()

                if i % 1000 == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                        epoch + 1, num_epochs_retrain, i + 1, total_step,
                        loss.item()))

                    with torch.no_grad():
                        correct = 0
                        total = 0
                        times = 0
                        for x, y in test_loader:
                            x = x.reshape(-1, time_steps,
                                          input_size).to(device)
                            y = y.to(device)
                            outputs = model_prune(x)
                            prob = F.softmax(outputs, dim=1)
                            _, predicted = torch.max(prob, 1)
                            total += y.size(0)
                            correct += (predicted == y).sum().item()
                            times = times + 1
                            if times > 100:
                                break

                        print(
                            'Test accuracy of the model on the sampled test batches: {} %'
                            .format(100 * correct / total))
                    # snapshot a copy: appending the live model would make
                    # every list entry alias the same final weights
                    model_ls.append(copy.deepcopy(model_prune))
                    accuracy.append(100 * correct / total)

        model_prune = model_ls[np.argmax(accuracy)]

        # test accuracy
        test_data = TensorDataset(torch.Tensor(X), torch.Tensor(Y).long())
        test_loader = DataLoader(dataset=test_data,
                                 batch_size=batch_size,
                                 shuffle=True)
        with torch.no_grad():
            correct = 0
            total = 0
            times = 0
            for x, y in test_loader:
                x = x.reshape(-1, time_steps, input_size).to(device)
                y = y.to(device)
                outputs = model_prune(x)
                prob = F.softmax(outputs, dim=1)
                _, predicted = torch.max(prob, 1)
                total += y.size(0)
                correct += (predicted == y).sum().item()
                times = times + 1
                if times > 100:
                    break

        print('Test accuracy of the model on the sampled test batches: {} %'.format(
            100 * correct / total))

        # quantize the pruned model
        quantized_model_prune = torch.quantization.quantize_dynamic(
            model_prune.to('cpu'), {nn.Embedding, nn.LSTM, nn.Linear},
            dtype=torch.qint8)

        return model_prune, quantized_model_prune
Example No. 20
def prune_model(method_name, parameters_to_prune, pruning_rate):
    if method_name == 'l1_unstructured':
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_rate,
        )

    elif method_name == 'random':
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.RandomUnstructured,
            amount=pruning_rate,
        )

    elif method_name == 'l1_structured':
        # apply once per parameter: repeating ln_structured on the same
        # tensor stacks masks and over-prunes
        for (module, name) in parameters_to_prune:
            prune.ln_structured(module=module,
                                name=name,
                                n=1,
                                amount=pruning_rate,
                                dim=-1)

    elif method_name == 'l2_structured':
        for (module, name) in parameters_to_prune:
            prune.ln_structured(module=module,
                                name=name,
                                n=2,
                                amount=pruning_rate,
                                dim=-1)

    else:
        raise ValueError("Pruning method not found")
Example No. 21
print("# conv1 pruned buffers")
print(list(module.named_buffers()))

# Prune bias using L1 norm, removing the 3 smallest entries

prune.l1_unstructured(module, name="bias", amount=3)
print("# conv1 pruned bias params")
print(list(module.named_parameters()))
print("# conv1 pruned bias buffers")
print(list(module.named_buffers()))
print("# Forward pre hooks")
print(module._forward_pre_hooks)

# Iterative pruning (pruning the same parameter again stacks masks in series; this zeros out 50% of the channels)

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# weights pruned
print(module.weight)
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

# pruning history
print(list(hook))

# Remove reparametrization
prune.remove(module, "weight")
print(list(module.named_parameters()))

# Prune multiple based on type (20% in conv and 40% in linear)
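The snippet ends at the comment above; a minimal sketch of what such type-based pruning could look like (the model is an assumption):

import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(),
                      nn.Flatten(), nn.Linear(8 * 26 * 26, 10))
for submodule in model.modules():
    if isinstance(submodule, nn.Conv2d):
        prune.l1_unstructured(submodule, name='weight', amount=0.2)
    elif isinstance(submodule, nn.Linear):
        prune.l1_unstructured(submodule, name='weight', amount=0.4)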
Example No. 22
def train(opt,Gs,Zs,reals,NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0

    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' %  (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr,G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of CNN block might differ. (every 4 levels according to the paper)
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))

        # in_s: initial signal; a zero tensor that does not change during training.
        if fine_tune:
          z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
          z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)


        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()

        #################################################################################
        # Visualize weights
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning weights Structured or Non-structured
        if not structured:
            modules = [G_curr.head.conv, G_curr.head.norm,
                    G_curr.body.block1.conv, G_curr.body.block1.norm,
                    G_curr.body.block2.conv, G_curr.body.block2.norm,
                    G_curr.body.block3.conv, G_curr.body.block3.norm,
                    G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )

            visualize_weights(modules, 'ori')

            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            modules = [G_curr.head.conv,
            G_curr.body.block1.conv,
            G_curr.body.block2.conv,
            G_curr.body.block3.conv]

            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)

            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)

        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr) 
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)

        # Fine-tuning
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()

            if not structured:
                # Keep training using inherited weights
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Training from scratch
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        visualize_weights(modules, 'fine-tune')

        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
              prune.remove(m, 'bias')
        
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)

        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
Example No. 23
    Module.append(newModel.features[i])
    frac = length - count
    pr.append(maxpr / frac)
    with open(outfile, 'a') as f:
        f.write("Global Sparsity: {:.2f}%".format(pr[count] * 100))
    count += 1
"""Select the amount of feature we want to prune in each Layer"""
epochs = 10
max_lr = 1e-3
grad_clip = .2
weight_decay = 1e-5
L1 = 1e-5
iterations = 5
for ittr in range(iterations):
    for i in range(len(Module)):
        prune.ln_structured(Module[i], name="weight", amount=pr[i], n=1, dim=0)
    numberOfZero = 0
    numberOfElements = 0
    totalNumberOfZero = 0
    totalNumberOfElements = 0
    for i, j in enumerate(prunelist):
        numberOfZero = torch.sum(Module[i].weight == 0)
        totalNumberOfZero += numberOfZero
        numberOfElements = Module[i].weight.nelement()
        totalNumberOfElements += numberOfElements
        frac = 100. * float(torch.sum(Module[i].weight == 0)) / float(
            Module[i].weight.nelement())
        with open(outfile, 'a') as f:
            f.write(f"\n {j} Sparsity in {Module[i]} is \t{frac}")

    with open(outfile, 'a') as f:
Example No. 24
 def prune_actor(self, ratios, dims):
     # if type(self.dist) is Categorical:
     prune.ln_structured(self.dist.linear, "weight", amount=ratios[-1], n=1, dim=dims[-1])
     # if type(self.dist) is DiagGaussian:
     # prune.ln_structured(self.dist.fc_mean, "weight", amount=ratios[-1], n=1, dim=dims[-1])
     self.base.prune_actor(ratios[:-1], dims[:-1])
Example No. 25
 def __init__(self,
              in_planes,
              planes,
              num_bits,
              num_bits_weight,
              stride,
              type_prune,
              sparsity,
              layer_num,
              option='A'):
     super(ResNetBlock, self).__init__()
     if in_planes == 3:
         op = QConv2d(in_planes,
                      planes,
                      num_bits,
                      num_bits_weight,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False)
         if type_prune == 'channel':
             op = prune.ln_structured(op,
                                      name='weight',
                                      amount=sparsity,
                                      n=2,
                                      dim=0)
         elif type_prune == 'group':
             width = 4
             tmp_pruned = op.weight.data.clone()
             original_size = tmp_pruned.size()
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             append_size = width - tmp_pruned.shape[1] % width
             tmp_pruned = torch.cat(
                 (tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
             tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
             tmp_pruned = tmp_pruned.pow(2.0).mean(
                 2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
             tmp = tmp_pruned.flatten()
             num = tmp.shape[0] * (1 - sparsity)
             top_k = torch.topk(tmp, int(num), sorted=True)
             threshold = top_k.values[-1]
             tmp_pruned = tmp_pruned.ge(threshold)
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             tmp_pruned = tmp_pruned[:, 0:op.weight.data[0].nelement()]
             tmp_pruned = tmp_pruned.contiguous().view(original_size)
             op = prune.custom_from_mask(op, name='weight', mask=tmp_pruned)
         self.add_module("conv", op)
         bn_op = nn.BatchNorm2d(planes)
         self.add_module("bn", bn_op)
         self.add_module("relu", nn.ReLU(inplace=True))
     elif in_planes == 1:
         self.add_module("avg_pool", nn.AvgPool2d(kernel_size=8, stride=1))
         self.add_module("flatten", Flatten())
         op = nn.Linear(in_features=64, out_features=10)
         if type_prune == 'channel':
             op = prune.ln_structured(op,
                                      name='weight',
                                      amount=sparsity,
                                      n=2,
                                      dim=0)
         elif type_prune == 'group':
             width = 4
             tmp_pruned = op.weight.data.clone()
             original_size = tmp_pruned.size()
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             append_size = width - tmp_pruned.shape[1] % width
             tmp_pruned = torch.cat(
                 (tmp_pruned, tmp_pruned[:, 0:append_size]), 1)
             tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
             tmp_pruned = tmp_pruned.pow(2.0).mean(
                 2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
             tmp = tmp_pruned.flatten()
             num = tmp.shape[0] * (1 - sparsity)
             top_k = torch.topk(tmp, int(num), sorted=True)
             threshold = top_k.values[-1]
             tmp_pruned = tmp_pruned.ge(threshold)
             tmp_pruned = tmp_pruned.view(original_size[0], -1)
             tmp_pruned = tmp_pruned[:, 0:op.weight.data[0].nelement()]
             tmp_pruned = tmp_pruned.contiguous().view(original_size)
             op = prune.custom_from_mask(op, name='weight', mask=tmp_pruned)
         self.add_module("fc", op)
     else:
         op = BasicBlock(in_planes, planes, num_bits, num_bits_weight,
                         stride, type_prune, sparsity, layer_num, option)
         self.add_module("conv", op)
Example No. 26
 def global_pruning(self, p_to_delete, dim=0):
     for target_module in self.target_modules:
         prune.ln_structured(target_module, name="weight", dim=dim, amount=p_to_delete, n=1)  # dim is where weights get removed (row: 1, col: 0?); which dim is it better to prune on?
Example No. 27
    def __init__(self,
                 in_planes,
                 planes,
                 num_bits,
                 num_bits_weight,
                 stride,
                 type_prune,
                 sparsity,
                 layer_num,
                 option='A'):
        super(BasicBlock, self).__init__()

        self.conv1 = QConv2d(in_planes,
                             planes,
                             num_bits,
                             num_bits_weight,
                             kernel_size=3,
                             stride=stride,
                             padding=1,
                             bias=False)

        if type_prune == 'channel':
            self.conv1 = prune.ln_structured(self.conv1,
                                             name='weight',
                                             amount=sparsity,
                                             n=2,
                                             dim=0)
        elif type_prune == 'group':
            width = 4
            tmp_pruned = self.conv1.weight.data.clone()
            original_size = tmp_pruned.size()
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            append_size = width - tmp_pruned.shape[1] % width
            tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]),
                                   1)
            tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
            tmp_pruned = tmp_pruned.pow(2.0).mean(
                2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
            tmp = tmp_pruned.flatten()
            num = tmp.shape[0] * (1 - sparsity)
            top_k = torch.topk(tmp, int(num), sorted=True)
            threshold = top_k.values[-1]
            tmp_pruned = tmp_pruned.ge(threshold)
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            tmp_pruned = tmp_pruned[:, 0:self.conv1.weight.data[0].nelement()]
            tmp_pruned = tmp_pruned.contiguous().view(original_size)
            self.conv1 = prune.custom_from_mask(self.conv1,
                                                name='weight',
                                                mask=tmp_pruned)

        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = QConv2d(planes,
                             planes,
                             num_bits,
                             num_bits_weight,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             bias=False)
        if type_prune == 'channel':
            self.conv2 = prune.ln_structured(self.conv2,
                                             name='weight',
                                             amount=sparsity,
                                             n=2,
                                             dim=0)
        elif type_prune == 'group':
            width = 4
            tmp_pruned = self.conv2.weight.data.clone()
            original_size = tmp_pruned.size()
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            append_size = width - tmp_pruned.shape[1] % width
            tmp_pruned = torch.cat((tmp_pruned, tmp_pruned[:, 0:append_size]),
                                   1)
            tmp_pruned = tmp_pruned.view(tmp_pruned.shape[0], -1, width)
            tmp_pruned = tmp_pruned.pow(2.0).mean(
                2, keepdim=True).pow(0.5).expand(tmp_pruned.shape)
            tmp = tmp_pruned.flatten()
            num = tmp.shape[0] * (1 - sparsity)
            top_k = torch.topk(tmp, int(num), sorted=True)
            threshold = top_k.values[-1]
            tmp_pruned = tmp_pruned.ge(threshold)
            tmp_pruned = tmp_pruned.view(original_size[0], -1)
            tmp_pruned = tmp_pruned[:, 0:self.conv2.weight.data[0].nelement()]
            tmp_pruned = tmp_pruned.contiguous().view(original_size)
            self.conv2 = prune.custom_from_mask(self.conv2,
                                                name='weight',
                                                mask=tmp_pruned)

        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x: F.pad(
                    x[:, :, ::2, ::2],
                    (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    QConv2d(in_planes,
                            self.expansion * planes,
                            num_bits,
                            num_bits_weight,
                            kernel_size=1,
                            stride=stride,
                            bias=False),
                    nn.BatchNorm2d(self.expansion * planes))
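The 'group' branch above is duplicated verbatim for conv1 and conv2 (and again in Example No. 25); a possible refactor (an assumption, not from the source) extracts it into a helper:

import torch
import torch.nn.utils.prune as prune

def group_prune(module, sparsity, width=4):
    # score each group of `width` consecutive weights by its RMS, then keep
    # the top (1 - sparsity) fraction of groups via a custom mask
    w = module.weight.data.clone()
    size = w.size()
    w = w.view(size[0], -1)
    pad = width - w.shape[1] % width
    w = torch.cat((w, w[:, 0:pad]), 1)
    w = w.view(w.shape[0], -1, width)
    score = w.pow(2.0).mean(2, keepdim=True).pow(0.5).expand(w.shape)
    keep = int(score.numel() * (1 - sparsity))
    threshold = torch.topk(score.flatten(), keep, sorted=True).values[-1]
    mask = score.ge(threshold).view(size[0], -1)
    mask = mask[:, 0:module.weight.data[0].nelement()].contiguous().view(size)
    return prune.custom_from_mask(module, name='weight', mask=mask)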
Example No. 28
    def __init__(self, args, config):

        # Init arguments
        self.args = args
        self.config = config
        self.config.experiment_start_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.config.experiment_name = self.experiment_name
        if not self.args.dry_run:
            self.checkpoints_dir_path = os.path.join(self.output_dir_path, 'checkpoints')
            self.setup_experiment_output()
        self.logger = Logger(args, config, self.output_dir_path)
        self.logger.log_config()

        # Randomness
        random.seed(config.random_seed)
        torch.manual_seed(config.random_seed)
        torch.cuda.manual_seed_all(config.random_seed)

        # Datasets
        if config.dataset == 'CIFAR10' or config.dataset == 'CIFAR100':
            self.input_size = 32
            train_loader, test_loader, num_classes = cifar_dataloader.get_loaders(config.dataset,
                                                                                  args.datadir,
                                                                                  args.batch_size,
                                                                                  args.num_workers)
            self.train_loader = train_loader
            self.test_loader = test_loader
            self.num_classes = num_classes

        elif config.dataset == 'GTSRB':
            self.input_size = 32
            train_loader, val_loader, test_loader, num_classes = gtsrb_dataloader.cf_gtsrb.get_loaders(args.datadir,
                                                                                  args.batch_size,
                                                                                  args.num_workers)

            self.train_loader = train_loader
            self.val_loader = val_loader
            self.test_loader = test_loader
            self.num_classes = num_classes

        else:
            raise Exception("Dataset not supported: {}".format(config.dataset))

        # Init starting values
        self.starting_epoch = 1
        self.best_val_acc = 0

        # Setup device
        if self.args.gpus is not None:
            self.args.gpus = [int(i) for i in self.args.gpus.split(',')]
            self.device = 'cuda:' + str(args.gpus[0])
            torch.backends.cudnn.benchmark = True
        else:
            self.device = 'cpu'
        self.device = torch.device(self.device)

        # Setup model
        model = get_model(self.config, self.num_classes, self.input_size)

        # Resume model, if any
        if args.resume:
            print('Loading model checkpoint at: {}'.format(args.resume))
            package = torch.load(args.resume, map_location=self.device)
            model_state_dict = package['state_dict']
            #model_state_dict = utils.state_dict_retrocompatibility(model_state_dict)
            model.load_state_dict(model_state_dict, strict=args.strict)

        if args.pruned_retrain:
            for name, module in model.named_modules():
                # structured-prune 70% of output channels in all 2D-conv layers
                if isinstance(module, torch.nn.Conv2d):
                    prune.ln_structured(module, name='weight', amount=0.7, n=1, dim=0)

        self.model = model.to(device=self.device)
        if self.args.gpus is not None and len(self.args.gpus) > 1:
            self.model = nn.DataParallel(self.model, self.args.gpus)

        #Loss function
        self.criterion = nn.CrossEntropyLoss()
        self.criterion = self.criterion.to(device=self.device)

        # Init optimizer
        self.optimizer = self.model  # setter syntax

        # Resume optimizer, if any
        if args.resume and not args.evaluate and not args.pruned_retrain:
            self.logger.log.info("Loading optimizer checkpoint")
            if 'optim_dict' in package.keys():
                self.optimizer.load_state_dict(package['optim_dict'])
            if 'epoch' in package.keys():
                self.starting_epoch = package['epoch']

        # LR scheduler
        self.scheduler = self.optimizer  # setter syntax

        # Resume scheduler, if any
        if args.resume \
                and not args.evaluate \
                and self.scheduler is not None and 'epoch' in package.keys():
            self.scheduler.last_epoch = package['epoch'] - 1

        # Recap
        self.logger.log_cmd_args()
Example No. 29
 def _prune_res_unit1(self, ratio=0.1):
     prune.ln_structured(list(self.conv1)[0],
                         name="weight",
                         amount=ratio,
                         n=1,
                         dim=0)
Example No. 30
def prune_model(model, prune_protopyte):
    model = copy.deepcopy(model)
    prune_protopyte = copy.deepcopy(prune_protopyte)

    for idx, (data_1, data_2) in enumerate(
            zip(model.named_modules(), prune_protopyte.named_modules())):
        if idx == 0:
            continue

        name_1, module_1 = data_1[0], data_1[1]
        name_2, module_2 = data_2[0], data_2[1]

        if isinstance(module_1, nn.Conv2d) or isinstance(module_1, nn.Linear):
            w_shape_1 = torch.tensor(module_1.weight.shape)
            w_shape_2 = torch.tensor(module_2.weight.shape)
            w_diff = torch.abs(w_shape_1 - w_shape_2)

            if w_diff[0] > 0 or w_diff[1] > 0:
                if w_diff[0] > 0:
                    prune.ln_structured(module_1,
                                        name="weight",
                                        amount=int(w_diff[0].item()),
                                        n=1,
                                        dim=0)

                if w_diff[1] > 0:
                    prune.ln_structured(module_1,
                                        name="weight",
                                        amount=int(w_diff[1].item()),
                                        n=1,
                                        dim=1)

                mask = module_1.weight_mask
                w = torch.where(mask != 0)
                w_mask = torch.unique(w[0])
                module_1.register_parameter('w_mask',
                                            nn.Parameter(w_mask.float()))

            continue

        if isinstance(module_1, nn.BatchNorm2d):
            w_shape_1 = torch.tensor(module_1.weight.shape)
            w_shape_2 = torch.tensor(module_2.weight.shape)
            w_diff = torch.abs(w_shape_1 - w_shape_2)

            if w_diff[0] > 0:
                prune.l1_unstructured(module_1, name="weight", amount=1.0)

    tree = []
    tree_dict = {}
    for idx, (name, module) in enumerate(model.named_modules()):
        if idx == 0:
            continue

        if isinstance(module, nn.Conv2d):
            tree.append([name, 'Conv2d'])
            tree_dict[name] = 'Conv2d'

        if isinstance(module, nn.BatchNorm2d):
            tree.append([name, 'BatchNorm2d'])
            tree_dict[name] = 'BatchNorm2d'

        if isinstance(module, nn.Linear):
            tree.append([name, 'Linear'])
            tree_dict[name] = 'Linear'

    bn_dependencies = {}
    for idx, t in enumerate(tree):
        if t[1] == 'BatchNorm2d' and idx == 0:
            raise Exception('BatchNorm2d has no preceding layer to depend on')

        if t[1] == 'BatchNorm2d':
            bn_dependencies[t[0]] = tree[idx - 1][0]

    prune_protopyte_state_dict = prune_protopyte.state_dict()
    for key in prune_protopyte.state_dict().keys():
        prune_protopyte_state_dict[key].fill_(0)

    for layer in tree_dict.keys():
        if f'{layer}.weight_orig' in model.state_dict().keys(
        ) and f'{layer}.weight_mask' in model.state_dict().keys():
            if tree_dict[f'{layer}'] in ['Conv2d', 'Linear']:
                weights = model.state_dict()[f'{layer}.weight_orig']
                mask = model.state_dict()[f'{layer}.weight_mask']

                prune_protopyte_state_dict[f'{layer}.weight'] = weights[
                    mask.bool()].reshape(
                        prune_protopyte_state_dict[f'{layer}.weight'].shape)

                if f'{layer}.bias' in model.state_dict().keys():
                    bias = model.state_dict()[f'{layer}.bias']
                    w_mask = model.state_dict()[f'{layer}.w_mask'].long()
                    prune_protopyte_state_dict[f'{layer}.bias'] = bias[
                        w_mask].reshape(
                            prune_protopyte_state_dict[f'{layer}.bias'].shape)
                continue

            if tree_dict[f'{layer}'] == 'BatchNorm2d':
                weights = model.state_dict()[f'{layer}.weight_orig']
                running_mean = model.state_dict()[f'{layer}.running_mean']
                running_var = model.state_dict()[f'{layer}.running_var']

                w_mask = model.state_dict(
                )[f'{bn_dependencies[layer]}.w_mask'].long()

                prune_protopyte_state_dict[f'{layer}.weight'] = weights[
                    w_mask].reshape(
                        prune_protopyte_state_dict[f'{layer}.weight'].shape)
                prune_protopyte_state_dict[
                    f'{layer}.running_mean'] = running_mean[w_mask].reshape(
                        prune_protopyte_state_dict[f'{layer}.running_mean'].
                        shape)
                prune_protopyte_state_dict[
                    f'{layer}.running_var'] = running_var[w_mask].reshape(
                        prune_protopyte_state_dict[f'{layer}.running_var'].
                        shape)

                if f'{layer}.bias' in model.state_dict().keys():
                    bias = model.state_dict()[f'{layer}.bias']
                    prune_protopyte_state_dict[f'{layer}.bias'] = bias[
                        w_mask].reshape(
                            prune_protopyte_state_dict[f'{layer}.bias'].shape)
                continue
        else:
            if tree_dict[f'{layer}'] in ['Conv2d', 'Linear']:
                prune_protopyte_state_dict[
                    f'{layer}.weight'] = model.state_dict()[f'{layer}.weight']
                if f'{layer}.bias' in model.state_dict().keys():
                    prune_protopyte_state_dict[
                        f'{layer}.bias'] = model.state_dict()[f'{layer}.bias']

            if tree_dict[f'{layer}'] == 'BatchNorm2d':
                prune_protopyte_state_dict[
                    f'{layer}.weight'] = model.state_dict()[f'{layer}.weight']
                prune_protopyte_state_dict[
                    f'{layer}.running_mean'] = model.state_dict(
                    )[f'{layer}.running_mean']
                prune_protopyte_state_dict[
                    f'{layer}.running_var'] = model.state_dict(
                    )[f'{layer}.running_var']

                if f'{layer}.bias' in model.state_dict().keys():
                    prune_protopyte_state_dict[
                        f'{layer}.bias'] = model.state_dict()[f'{layer}.bias']

    prune_protopyte.load_state_dict(prune_protopyte_state_dict)
    return prune_protopyte
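Hypothetical usage (both networks are assumptions): the "prototype" is an architecturally smaller copy of the model; its layer shapes tell prune_model how many channels to strip, and the returned prototype receives the surviving weights.

import torch.nn as nn

big = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16),
                    nn.Flatten(), nn.Linear(16 * 30 * 30, 10))
small = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8),
                      nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
compact = prune_model(big, small)  # small, loaded with big's surviving weights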