Пример #1
0
def compute_pindex(modules, percentage, threshold):
    """Collect decomposition and pruning indices for a list of modules.

    For every ``Dense`` or ``Transition`` module the relevant convolution is
    looked up, its weight is squeezed and transposed, and
    ``get_nonzero_index`` is evaluated on it twice: once with
    ``dim='input'`` (the "decom" results) and once with ``dim='output'``
    (the "prune" results).

    Args:
        modules: iterable of ``Dense`` / ``Transition`` modules.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.

    Returns:
        Four lists with one entry per module, ordered as
        ``(pindex_decom, pindex_prune, norm_decom, norm_prune)``.

    Raises:
        NotImplementedError: if a module is neither ``Dense`` nor
            ``Transition``.
    """
    decom_indices, decom_norms = [], []
    prune_indices, prune_norms = [], []
    # Arguments shared by both get_nonzero_index calls.
    common = dict(counter=1, percentage=percentage, threshold=threshold)

    for layer in modules:
        if isinstance(layer, Dense):
            target_conv = layer._modules['body']._modules['3']
        elif isinstance(layer, Transition):
            target_conv = layer._modules['3']
        else:
            raise NotImplementedError('Do not need to prune the layer ' +
                                      layer.__class__.__name__)

        matrix = target_conv.weight.squeeze().t()
        norm_in, index_in = get_nonzero_index(matrix, dim='input', **common)
        norm_out, index_out = get_nonzero_index(matrix, dim='output', **common)

        decom_indices.append(index_in)
        decom_norms.append(norm_in)
        prune_indices.append(index_out)
        prune_norms.append(norm_out)

    return decom_indices, prune_indices, decom_norms, prune_norms
Пример #2
0
def compress_module_param(module, percentage, threshold, index_pre=None, i=0):
    """Shrink a (conv11, conv12) pair and its batch norm in place.

    ``module[0][0]`` is conv11, ``module[0][1]`` is conv12 (its squeezed
    weight is handled as a 2-D projection matrix) and ``module[1]`` is the
    batch norm.  The kept indices come from ``get_nonzero_index`` applied
    to the projection with ``dim='input'`` (``pindex1``) and
    ``dim='output'`` (``pindex2``).

    Args:
        module: indexable container laid out as described above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        index_pre: optional 1-D index tensor selecting the input channels of
            conv11 to keep (presumably the output index of the preceding
            module -- confirm against the caller).  ``None`` keeps them all.
        i: position of the module; conv12's output channels, its bias and
            the batch norm are only sliced while ``i < 11``.

    Returns:
        None.  The layers are modified in place.
    """
    conv11 = module[0][0]
    conv12 = module[0][1]
    batchnorm1 = module[1]

    ws1 = conv11.weight.shape
    # Flatten conv11 to (in * kh * kw, out) so rows are input positions.
    weight1 = conv11.weight.data.view(ws1[0], -1).t()
    projection1 = conv12.weight.data.squeeze().t()
    bias1 = conv12.bias.data if conv12.bias is not None else None
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    pindex1 = get_nonzero_index(projection1,
                                dim='input',
                                counter=1,
                                percentage=percentage,
                                threshold=threshold)[1]
    pindex2 = get_nonzero_index(projection1,
                                dim='output',
                                counter=1,
                                percentage=percentage,
                                threshold=threshold)[1]

    # conv11
    if index_pre is not None:
        # Expand each kept input channel c into the rows
        # c * k, ..., c * k + k - 1 of the flattened weight (k = kh * kw).
        # The offsets live on index_pre's device instead of a hard-coded
        # CUDA device, so the function also works on CPU.
        kernel_area = ws1[2] * ws1[3]
        offsets = torch.arange(kernel_area, device=index_pre.device)
        index = torch.repeat_interleave(index_pre, kernel_area) * kernel_area \
                + offsets.repeat(index_pre.shape[0])
        weight1 = torch.index_select(weight1, dim=0, index=index)
    weight1 = torch.index_select(weight1, dim=1, index=pindex1)
    conv11.weight = nn.Parameter(weight1.t().view(pindex1.shape[0], -1, ws1[2],
                                                  ws1[3]))
    conv11.out_channels, conv11.in_channels = conv11.weight.size()[:2]

    # conv12: projection1, bias1
    projection1 = torch.index_select(projection1, dim=0, index=pindex1)
    # NOTE(review): 11 is hard-coded; presumably modules with i >= 11 must
    # keep their output width -- confirm against the caller.
    if i < 11:
        projection1 = torch.index_select(projection1, dim=1, index=pindex2)
        if bias1 is not None:
            conv12.bias = nn.Parameter(
                torch.index_select(bias1, dim=0, index=pindex2))

        # compress batchnorm1
        batchnorm1.weight = nn.Parameter(
            torch.index_select(bn_weight1, dim=0, index=pindex2))
        batchnorm1.bias = nn.Parameter(
            torch.index_select(bn_bias1, dim=0, index=pindex2))
        batchnorm1.running_mean = torch.index_select(bn_mean1,
                                                     dim=0,
                                                     index=pindex2)
        batchnorm1.running_var = torch.index_select(bn_var1,
                                                    dim=0,
                                                    index=pindex2)
        batchnorm1.num_features = len(batchnorm1.weight)

    conv12.weight = nn.Parameter(projection1.t().view(
        -1, pindex1.shape[0], 1, 1))  #TODO: check this one.
    conv12.out_channels, conv12.in_channels = conv12.weight.size()[:2]
def compress_module_param(module, percentage, threshold, p1_p2_same_ratio):
    """Slice both (conv, 1x1 projection) pairs of ``module.body`` in place.

    ``body['0']`` and ``body['3']`` each hold a convolution (sub-module
    ``'0'``) followed by a projection (sub-module ``'1'``).  For each pair,
    ``get_nonzero_index(..., dim='input')`` on the projection selects which
    output channels of the convolution / input channels of the projection
    are kept.

    Args:
        module: module exposing ``_modules['body']`` laid out as above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: when truthy, the number of indices kept for the
            first pair is passed as ``fix_channel`` for the second pair;
            otherwise ``fix_channel`` is 0.

    Returns:
        None.  The four layers are modified in place.
    """
    body = module._modules['body']
    first_pair = body._modules['0']
    conv11, conv12 = first_pair._modules['0'], first_pair._modules['1']
    second_pair = body._modules['3']
    conv21, conv22 = second_pair._modules['0'], second_pair._modules['1']

    shape1 = conv11.weight.shape
    flat1 = conv11.weight.data.view(shape1[0], -1).t()
    proj1 = conv12.weight.data.squeeze().t()

    shape2 = conv21.weight.shape
    flat2 = conv21.weight.data.view(shape2[0], -1).t()
    proj2 = conv22.weight.data.squeeze().t()

    _, keep1 = get_nonzero_index(proj1, dim='input', counter=1,
                                 percentage=percentage, threshold=threshold)
    if p1_p2_same_ratio:
        fix_channel = len(keep1)
    else:
        fix_channel = 0
    _, keep2 = get_nonzero_index(proj2, dim='input', counter=1,
                                 percentage=percentage, threshold=threshold,
                                 fix_channel=fix_channel)
    n_keep1 = keep1.shape[0]
    n_keep2 = keep2.shape[0]

    # First pair: keep the selected output channels of conv11 ...
    flat1 = torch.index_select(flat1, dim=1, index=keep1)
    conv11.weight = nn.Parameter(
        flat1.t().view(n_keep1, shape1[1], shape1[2], shape1[3]))
    conv11.out_channels = conv11.weight.size()[0]
    conv11.in_channels = conv11.weight.size()[1]
    # ... and the matching input channels of conv12.
    proj1 = torch.index_select(proj1, dim=0, index=keep1)
    conv12.weight = nn.Parameter(proj1.t().view(shape1[0], n_keep1, 1, 1))
    conv12.out_channels = conv12.weight.size()[0]
    conv12.in_channels = conv12.weight.size()[1]

    # Second pair: same treatment with its own index.
    flat2 = torch.index_select(flat2, dim=1, index=keep2)
    conv21.weight = nn.Parameter(
        flat2.t().view(n_keep2, shape2[1], shape2[2], shape2[3]))
    conv21.out_channels = conv21.weight.size()[0]
    conv21.in_channels = conv21.weight.size()[1]
    proj2 = torch.index_select(proj2, dim=0, index=keep2)
    conv22.weight = nn.Parameter(proj2.t().view(shape2[0], n_keep2, 1, 1))
    conv22.out_channels = conv22.weight.size()[0]
    conv22.in_channels = conv22.weight.size()[1]
Пример #4
0
 def index_pre(self, percentage, threshold):
     """Return the 'output' index of every module's projection.

     For each module yielded by ``self.find_modules()`` the squeezed,
     transposed weight of ``module[0][1]`` is passed to
     ``get_nonzero_index`` with ``dim='output'``; the second element of
     each result is collected into a list.
     """
     def _output_index(module_cur):
         projection = module_cur[0][1].weight.data.squeeze().t()
         result = get_nonzero_index(projection, dim='output', counter=1,
                                    percentage=percentage, threshold=threshold)
         return result[1]

     return [_output_index(module_cur) for module_cur in self.find_modules()]
Пример #5
0
def get_compress_idx(module, percentage, threshold):
    """Compute compression statistics for the projection ``module[0][1]``.

    ``get_nonzero_index`` is evaluated on the squeezed, transposed projection
    weight with ``dim='input'`` (decomposition) and ``dim='output'``
    (pruning).

    Returns:
        A two-element list of ``edict`` objects (decomposition first, then
        pruning) with the keys ``stat_channel`` (total channels, removed
        channels, removed fraction), ``stat_remain_norm`` (max, mean, min of
        the kept norms), ``remain_norm`` and ``pindex``.
    """
    def _summarize(norm, pindex):
        kept = norm[pindex]
        total = norm.shape[0]
        removed = total - kept.shape[0]
        kept = kept.detach().cpu()
        return edict({
            'stat_channel': [total, removed, removed / total],
            'stat_remain_norm': [kept.max(), kept.mean(), kept.min()],
            'remain_norm': kept,
            'pindex': pindex,
        })

    projection = module[0][1].weight.data.squeeze().t()
    # Run both selections first, then summarize, as in the original order.
    selections = [
        get_nonzero_index(projection, dim=axis, counter=1,
                          percentage=percentage, threshold=threshold)
        for axis in ('input', 'output')
    ]
    return [_summarize(norm, pindex) for norm, pindex in selections]
Пример #6
0
def get_compress_idx(module, percentage, threshold, p1_p2_same_ratio):
    """Compute compression statistics for ``body['0']`` and ``body['6']``.

    The squeezed, transposed weights of the two layers are passed to
    ``get_nonzero_index``: the first with ``dim='output'``, the second with
    ``dim='input'``.

    Args:
        module: module exposing ``_modules['body']`` with sub-modules ``'0'``
            and ``'6'``.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: when truthy, the number of indices kept for the
            first layer is passed as ``fix_channel`` for the second one;
            otherwise ``fix_channel`` is 0.

    Returns:
        A two-element list of ``edict`` objects (first layer, then second)
        with the keys ``stat_channel`` (total channels, removed channels,
        removed fraction), ``stat_remain_norm`` (max, mean, min of the kept
        norms), ``remain_norm`` and ``pindex``.
    """
    weight1 = module._modules['body']._modules['0'].weight.data.squeeze().t()
    weight3 = module._modules['body']._modules['6'].weight.data.squeeze().t()

    norm1, pindex1 = get_nonzero_index(weight1, dim='output', counter=1, percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    # The original passed the misspelled dim='intput'; the sibling
    # implementations all use 'input' for the second selection.
    norm3, pindex3 = get_nonzero_index(weight3, dim='input', counter=1, percentage=percentage, threshold=threshold, fix_channel=fix_channel)

    def _get_compress_statistics(norm, pindex):
        remain_norm = norm[pindex]
        channels = norm.shape[0]
        remain_channels = remain_norm.shape[0]
        remain_norm = remain_norm.detach().cpu()
        stat_channel = [channels, channels - remain_channels, (channels - remain_channels) / channels]
        stat_remain_norm = [remain_norm.max(), remain_norm.mean(), remain_norm.min()]
        return edict({'stat_channel': stat_channel, 'stat_remain_norm': stat_remain_norm,
                      'remain_norm': remain_norm, 'pindex': pindex})
    return [_get_compress_statistics(norm1, pindex1), _get_compress_statistics(norm3, pindex3)]
def get_compress_idx(module, percentage, threshold, p1_p2_same_ratio):
    """Compute compression statistics for both projections of ``module.body``.

    The projections are ``body['0']['1']`` and ``body['3']['1']``; each
    squeezed, transposed weight is passed to ``get_nonzero_index`` with
    ``dim='input'``.  When ``p1_p2_same_ratio`` is truthy the number of
    indices kept for the first projection is passed as ``fix_channel`` for
    the second one, otherwise ``fix_channel`` is 0.

    Returns:
        A two-element list of ``edict`` objects (first projection, then
        second) with the keys ``stat_channel`` (total channels, removed
        channels, removed fraction), ``stat_remain_norm`` (max, mean, min of
        the kept norms), ``remain_norm`` and ``pindex``.
    """
    def _summarize(norm, pindex):
        kept = norm[pindex]
        total = norm.shape[0]
        removed = total - kept.shape[0]
        kept = kept.detach().cpu()
        return edict({
            'stat_channel': [total, removed, removed / total],
            'stat_remain_norm': [kept.max(), kept.mean(), kept.min()],
            'remain_norm': kept,
            'pindex': pindex,
        })

    body = module._modules['body']
    first_projection = body._modules['0']._modules['1']
    second_projection = body._modules['3']._modules['1']
    matrix1 = first_projection.weight.data.squeeze().t()
    matrix2 = second_projection.weight.data.squeeze().t()

    norm1, pindex1 = get_nonzero_index(matrix1, dim='input', counter=1,
                                       percentage=percentage,
                                       threshold=threshold)
    if p1_p2_same_ratio:
        fix_channel = len(pindex1)
    else:
        fix_channel = 0
    norm2, pindex2 = get_nonzero_index(matrix2, dim='input', counter=1,
                                       percentage=percentage,
                                       threshold=threshold,
                                       fix_channel=fix_channel)

    first_stats = _summarize(norm1, pindex1)
    second_stats = _summarize(norm2, pindex2)
    return [first_stats, second_stats]
Пример #8
0
    def index_pre(self, percentage, threshold):
        """Return the 'output' index of every module's first convolution.

        For each module yielded by ``self.find_modules()`` the weight of
        ``module[0]`` is flattened to ``(out, -1)``, transposed and passed to
        ``get_nonzero_index`` with ``dim='output'``; the second element of
        each result is collected into a list.
        """
        def _output_index(module_cur):
            conv = module_cur[0]
            out_channels = conv.weight.shape[0]
            flattened = conv.weight.data.view(out_channels, -1).t()
            result = get_nonzero_index(flattened,
                                       dim='output',
                                       percentage=percentage,
                                       threshold=threshold)
            return result[1]

        return [_output_index(module_cur) for module_cur in self.find_modules()]
Пример #9
0
def compress_module_param(module, percentage, threshold, index_pre=None, i=0):
    """Prune a convolution and its batch norm in place.

    ``module[0]`` is the convolution and ``module[1]`` the batch norm.  The
    kept output channels come from ``get_nonzero_index`` applied with
    ``dim='output'`` to the convolution weight flattened to
    ``(in * kh * kw, out)``.

    Args:
        module: indexable container laid out as described above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        index_pre: optional 1-D index tensor selecting the input channels to
            keep (presumably the output index of the preceding module --
            confirm against the caller).  ``None`` keeps them all.
        i: position of the module; output channels, bias and batch norm are
            only sliced while ``i < 11``.

    Returns:
        None.  The layers are modified in place.
    """
    conv11 = module[0]
    batchnorm1 = module[1]

    ws1 = conv11.weight.shape
    weight1 = conv11.weight.data.view(ws1[0], -1).t()

    bias1 = conv11.bias.data if conv11.bias is not None else None

    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    pindex1 = get_nonzero_index(weight1,
                                dim='output',
                                percentage=percentage,
                                threshold=threshold)[1]

    # conv11
    if index_pre is not None:
        # Expand each kept input channel c into the rows
        # c * k, ..., c * k + k - 1 of the flattened weight (k = kh * kw).
        # The offsets live on index_pre's device instead of a hard-coded
        # CUDA device, so the function also works on CPU.
        kernel_area = ws1[2] * ws1[3]
        offsets = torch.arange(kernel_area, device=index_pre.device)
        index = torch.repeat_interleave(index_pre, kernel_area) * kernel_area \
                + offsets.repeat(index_pre.shape[0])
        weight1 = torch.index_select(weight1, dim=0, index=index)
    # NOTE(review): 11 is hard-coded; presumably modules with i >= 11 must
    # keep their output width -- confirm against the caller.
    if i < 11:
        weight1 = torch.index_select(weight1, dim=1, index=pindex1)
        # A bias-free convolution used to crash here because index_select
        # was called on None.
        if bias1 is not None:
            conv11.bias = nn.Parameter(
                torch.index_select(bias1, dim=0, index=pindex1))
        conv11.weight = nn.Parameter(weight1.t().view(pindex1.shape[0], -1,
                                                      ws1[2], ws1[3]))

        batchnorm1.weight = nn.Parameter(
            torch.index_select(bn_weight1, dim=0, index=pindex1))
        batchnorm1.bias = nn.Parameter(
            torch.index_select(bn_bias1, dim=0, index=pindex1))
        batchnorm1.running_mean = torch.index_select(bn_mean1,
                                                     dim=0,
                                                     index=pindex1)
        batchnorm1.running_var = torch.index_select(bn_var1,
                                                    dim=0,
                                                    index=pindex1)
        batchnorm1.num_features = len(batchnorm1.weight)
    else:
        conv11.weight = nn.Parameter(weight1.t().view(ws1[0], -1, ws1[2],
                                                      ws1[3]))

    conv11.out_channels, conv11.in_channels = conv11.weight.size()[:2]
def get_compress_idx(module, percentage, threshold):
    """Compute group-wise compression statistics for a grouped block.

    The weights of ``conv1`` and (transposed) ``conv3`` are reshaped so that
    their first dimension has ``conv2.groups`` rows, concatenated along
    dim 1, and passed to ``get_nonzero_index`` with ``dim='input'``.  The
    shape of ``conv1``'s weight is printed as a side effect.

    Returns:
        A one-element list holding an ``edict`` with the keys
        ``stat_channel`` (total rows, removed rows, removed fraction),
        ``stat_remain_norm`` (max, mean, min of the kept norms),
        ``remain_norm`` and ``pindex``.
    """
    def _summarize(norm, pindex):
        kept = norm[pindex]
        total = norm.shape[0]
        removed = total - kept.shape[0]
        kept = kept.detach().cpu()
        return edict({
            'stat_channel': [total, removed, removed / total],
            'stat_remain_norm': [kept.max(), kept.mean(), kept.min()],
            'remain_norm': kept,
            'pindex': pindex,
        })

    layers = module._modules
    conv1, conv2, conv3 = layers['conv1'], layers['conv2'], layers['conv3']
    groups = conv2.groups
    print(conv1.weight.data.shape)

    per_group_first = conv1.weight.data.squeeze().view(groups, -1)
    per_group_last = conv3.weight.data.squeeze().t().reshape(groups, -1)
    joint = torch.cat([per_group_first, per_group_last], dim=1)

    norm, pindex = get_nonzero_index(joint, dim='input', counter=1,
                                     percentage=percentage,
                                     threshold=threshold)
    return [_summarize(norm, pindex)]
Пример #11
0
def compress_module_param(percentage, threshold, p1_p2_same_ratio, **kwargs):
    """Prune a conv-bn-conv-bn-conv body in place.

    ``body['0']``, ``body['3']`` and ``body['6']`` are the three
    convolutions; ``body['1']`` and ``body['4']`` are the batch norms after
    the first two.  ``pindex1`` (``dim='output'`` on the first convolution)
    selects the channels between conv1 and conv2, ``pindex3``
    (``dim='input'`` on the last convolution) the channels between conv2
    and conv3.

    Args:
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: when truthy, ``len(pindex1)`` is passed as
            ``fix_channel`` for the second selection; otherwise 0.
        **kwargs: must contain ``module``.  When ``load_original_param`` is
            truthy, ``module_teacher`` must be present as well and the
            indices are computed from the teacher's weights.

    Returns:
        None.  The layers are modified in place.
    """
    module = kwargs['module']
    body = module._modules['body']

    conv1 = body._modules['0']
    batchnorm1 = body._modules['1']
    conv2 = body._modules['3']
    batchnorm2 = body._modules['4']
    conv3 = body._modules['6']

    weight1 = conv1.weight.data.squeeze().t()
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    ws2 = conv2.weight.data.shape
    # Flattened and transposed to (in * kh * kw, out).
    weight2 = conv2.weight.data.view(ws2[0], ws2[1] * ws2[2] * ws2[3]).t()
    bn_weight2 = batchnorm2.weight.data
    bn_bias2 = batchnorm2.bias.data
    bn_mean2 = batchnorm2.running_mean.data
    bn_var2 = batchnorm2.running_var.data

    weight3 = conv3.weight.data.squeeze().t()

    if kwargs.get('load_original_param'):  # need to pay special attention here
        teacher_body = kwargs['module_teacher']._modules['body']
        weight1_teach = teacher_body._modules['0'].weight.data.squeeze().t()
        # NOTE(review): unlike weight3 above, the teacher weight is NOT
        # transposed here -- looks inconsistent; confirm before changing.
        weight3_teach = teacher_body._modules['6'].weight.data.squeeze()
    else:
        weight1_teach = weight1 #TODO: whether to use copy.copy here?
        weight3_teach = weight3
    _, pindex1 = get_nonzero_index(weight1_teach, dim='output', counter=1, percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    # The original passed the misspelled dim='intput'; pindex3 is used below
    # to index dim 0 (the input side) of weight3, so 'input' is intended.
    _, pindex3 = get_nonzero_index(weight3_teach, dim='input', counter=1, percentage=percentage, threshold=threshold, fix_channel=fix_channel)
    pl1, pl3 = len(pindex1), len(pindex3)
    # conv1
    conv1.weight = nn.Parameter(torch.index_select(weight1, dim=1, index=pindex1).t().view(pl1, -1, 1, 1))
    conv1.out_channels = pl1
    # batchnorm1
    batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex1))
    batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex1))
    batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex1)
    batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex1)
    batchnorm1.num_features = pl1
    # conv2
    # Expand each kept input channel c into the rows c * k, ..., c * k + k - 1
    # of the flattened weight (k = kh * kw).  The offsets are created on
    # pindex1's device instead of a hard-coded CUDA device.
    kernel_area = ws2[2] * ws2[3]
    offsets = torch.arange(kernel_area, device=pindex1.device)
    index = torch.repeat_interleave(pindex1, kernel_area) * kernel_area + \
            offsets.repeat(pindex1.shape[0])
    weight2 = torch.index_select(weight2, dim=0, index=index)
    weight2 = torch.index_select(weight2, dim=1, index=pindex3)
    # weight2 is (in * kh * kw, out) here, so it has to be transposed back
    # before being reshaped to (out, in, kh, kw); the original reshaped it
    # directly, which scrambles the kept weights.  The kernel size is taken
    # from the original weight instead of being hard-coded to 3x3.
    conv2.weight = nn.Parameter(weight2.t().contiguous().view(pl3, pl1, ws2[2], ws2[3]))
    conv2.out_channels, conv2.in_channels = pl3, pl1
    # batchnorm2
    batchnorm2.weight = nn.Parameter(torch.index_select(bn_weight2, dim=0, index=pindex3))
    batchnorm2.bias = nn.Parameter(torch.index_select(bn_bias2, dim=0, index=pindex3))
    batchnorm2.running_mean = torch.index_select(bn_mean2, dim=0, index=pindex3)
    batchnorm2.running_var = torch.index_select(bn_var2, dim=0, index=pindex3)
    batchnorm2.num_features = pl3
    # conv3
    # NOTE(review): weight3 is (in, out) here and is viewed as (-1, pl3, 1, 1)
    # without transposing back, unlike conv1 above -- confirm the layout.
    conv3.weight = nn.Parameter(torch.index_select(weight3, dim=0, index=pindex3).view(-1, pl3, 1, 1))
    conv3.in_channels = pl3
def compress_module_param(module, percentage, threshold, p1_p2_same_ratio):
    """Prune the channels around ``bn2`` of a two-pair block in place.

    ``module.conv1`` and ``module.conv2`` each hold a convolution
    (sub-module ``'0'``) followed by a projection (sub-module ``'1'``);
    ``module.bn2`` is a batch norm.  ``pindex1`` (``dim='output'`` on the
    first projection) slices the projection's output channels and bias,
    ``bn2`` and the input channels of the second convolution.  ``pindex2``
    (``dim='input'`` on the second projection) slices the second
    convolution's output channels and the second projection's input
    channels.  The first convolution is left unchanged.

    Args:
        module: module exposing ``conv1``, ``conv2`` and ``bn2`` as above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: when truthy, ``len(pindex1)`` is passed as
            ``fix_channel`` for the second selection; otherwise 0.

    Returns:
        None.  The layers are modified in place.
    """
    conv11 = module._modules['conv1']._modules['0']
    conv12 = module._modules['conv1']._modules['1']
    conv21 = module._modules['conv2']._modules['0']
    conv22 = module._modules['conv2']._modules['1']
    batchnorm2 = module._modules['bn2']

    ws1 = conv11.weight.shape
    projection1 = conv12.weight.data.squeeze().t()
    # Guard against a bias-free projection, as the sibling implementation
    # for the 'body' layout already does.
    bias1 = conv12.bias.data if conv12.bias is not None else None

    bn_weight2 = batchnorm2.weight.data
    bn_bias2 = batchnorm2.bias.data
    bn_mean2 = batchnorm2.running_mean.data
    bn_var2 = batchnorm2.running_var.data

    ws2 = conv21.weight.shape
    # Flattened and transposed to (in * kh * kw, out).
    weight2 = conv21.weight.data.view(ws2[0], -1).t()
    projection2 = conv22.weight.data.squeeze().t()

    _, pindex1 = get_nonzero_index(projection1,
                                   dim='output',
                                   counter=1,
                                   percentage=percentage,
                                   threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    _, pindex2 = get_nonzero_index(projection2,
                                   dim='input',
                                   counter=1,
                                   percentage=percentage,
                                   threshold=threshold,
                                   fix_channel=fix_channel)

    # conv11 don't need to be changed.
    # compress conv12: projection1, bias1
    projection1 = torch.index_select(projection1, dim=1, index=pindex1)
    conv12.weight = nn.Parameter(projection1.t().view(
        pindex1.shape[0], ws1[0], 1, 1))  #TODO: check this one.
    if bias1 is not None:
        conv12.bias = nn.Parameter(
            torch.index_select(bias1, dim=0, index=pindex1))
    conv12.out_channels = conv12.weight.size()[0]
    # compress batchnorm2
    batchnorm2.weight = nn.Parameter(
        torch.index_select(bn_weight2, dim=0, index=pindex1))
    batchnorm2.bias = nn.Parameter(
        torch.index_select(bn_bias2, dim=0, index=pindex1))
    batchnorm2.running_mean = torch.index_select(bn_mean2,
                                                 dim=0,
                                                 index=pindex1)
    batchnorm2.running_var = torch.index_select(bn_var2, dim=0, index=pindex1)
    batchnorm2.num_features = len(batchnorm2.weight)
    # compress conv21
    # Expand each kept input channel c into the rows c * k, ..., c * k + k - 1
    # of the flattened weight (k = kh * kw).  The offsets are created on
    # pindex1's device instead of a hard-coded CUDA device.
    kernel_area = ws2[2] * ws2[3]
    offsets = torch.arange(kernel_area, device=pindex1.device)
    index = torch.repeat_interleave(pindex1, kernel_area) * kernel_area \
            + offsets.repeat(pindex1.shape[0])
    weight2 = torch.index_select(weight2, dim=0, index=index)
    weight2 = torch.index_select(weight2, dim=1, index=pindex2)
    conv21.weight = nn.Parameter(weight2.t().view(pindex2.shape[0],
                                                  pindex1.shape[0], ws2[2],
                                                  ws2[3]))
    conv21.out_channels, conv21.in_channels = conv21.weight.size()[:2]
    # compress conv22
    projection2 = torch.index_select(projection2, dim=0, index=pindex2)
    conv22.weight = nn.Parameter(projection2.t().view(-1, pindex2.shape[0], 1,
                                                      1))
    conv22.in_channels = conv22.weight.size()[1]
def compress_module_param(module, percentage, threshold, p1_p2_same_ratio):
    """Prune the channels around ``body['1']`` of a two-pair block in place.

    ``body['0']`` and ``body['3']`` each hold a convolution (sub-module
    ``'0'``) followed by a projection (sub-module ``'1'``); ``body['1']``
    is a batch norm.  ``pindex1`` (``dim='output'`` on the first
    projection) slices the projection's output channels and bias, the batch
    norm and the input channels of the second convolution.  ``pindex2``
    (``dim='input'`` on the second projection) slices the second
    convolution's output channels and the second projection's input
    channels.  The first convolution is left unchanged.

    Args:
        module: module exposing ``_modules['body']`` laid out as above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: when truthy, ``len(pindex1)`` is passed as
            ``fix_channel`` for the second selection; otherwise 0.

    Returns:
        None.  The layers are modified in place.
    """
    body = module._modules['body']
    conv11 = body._modules['0']._modules['0']
    conv12 = body._modules['0']._modules['1']
    batchnorm1 = body._modules['1']
    conv21 = body._modules['3']._modules['0']
    conv22 = body._modules['3']._modules['1']

    ws1 = conv11.weight.shape
    projection1 = conv12.weight.data.squeeze().t()
    bias1 = conv12.bias.data if conv12.bias is not None else None
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    ws2 = conv21.weight.shape
    # Flattened and transposed to (in * kh * kw, out).
    weight2 = conv21.weight.data.view(ws2[0], -1).t()
    projection2 = conv22.weight.data.squeeze().t()

    _, pindex1 = get_nonzero_index(projection1, dim='output', counter=1, percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    _, pindex2 = get_nonzero_index(projection2, dim='input', counter=1, percentage=percentage, threshold=threshold,
                                   fix_channel=fix_channel)

    # conv11 don't need to be changed.
    # compress conv12: projection1, bias1
    projection1 = torch.index_select(projection1, dim=1, index=pindex1)
    conv12.weight = nn.Parameter(projection1.t().view(pindex1.shape[0], ws1[0], 1, 1)) #TODO: check this one.
    if bias1 is not None:
        conv12.bias = nn.Parameter(torch.index_select(bias1, dim=0, index=pindex1))
    conv12.out_channels = conv12.weight.size()[0]
    # compress batchnorm1
    batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex1))
    batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex1))
    batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex1)
    batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex1)
    batchnorm1.num_features = len(batchnorm1.weight)
    # Expand each kept input channel c into the rows c * k, ..., c * k + k - 1
    # of the flattened weight (k = kh * kw).  The offsets are created on
    # pindex1's device instead of a hard-coded CUDA device.
    kernel_area = ws2[2] * ws2[3]
    offsets = torch.arange(kernel_area, device=pindex1.device)
    index = torch.repeat_interleave(pindex1, kernel_area) * kernel_area \
            + offsets.repeat(pindex1.shape[0])
    # compress conv21
    weight2 = torch.index_select(weight2, dim=0, index=index)
    weight2 = torch.index_select(weight2, dim=1, index=pindex2)
    conv21.weight = nn.Parameter(weight2.t().view(pindex2.shape[0], pindex1.shape[0], ws2[2], ws2[3]))
    conv21.out_channels, conv21.in_channels = conv21.weight.size()[:2]
    # compress conv22
    projection2 = torch.index_select(projection2, dim=0, index=pindex2)
    conv22.weight = nn.Parameter(projection2.t().view(-1, pindex2.shape[0], 1, 1))
    conv22.in_channels = conv22.weight.size()[1]
def compress_module_param(percentage, threshold, **kwargs):
    """Prune whole groups of a conv1-bn1-conv2-bn2-conv3 block in place.

    ``conv2`` is a grouped convolution.  The weights of ``conv1`` and
    (transposed) ``conv3`` are reshaped to one row per group, concatenated,
    and passed to ``get_nonzero_index`` with ``dim='input'``; the returned
    index selects the groups to keep.  All five layers are then sliced
    group-wise and ``conv2.groups`` is updated.

    Args:
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        **kwargs: must contain ``module``.  When ``load_original_param`` is
            truthy, ``module_teacher`` must be present as well and the group
            index is computed from the teacher's ``conv1`` / ``conv3``.

    Returns:
        None.  The layers are modified in place.
    """
    module = kwargs['module']

    conv1 = module._modules['conv1']
    batchnorm1 = module._modules['bn1']
    conv2 = module._modules['conv2']
    batchnorm2 = module._modules['bn2']
    conv3 = module._modules['conv3']

    groups = conv2.groups
    gs = conv2.in_channels // groups  # channels per group

    ws1 = conv1.weight.data.shape
    weight1 = conv1.weight.data.squeeze().view(groups, gs * conv1.in_channels)
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    ws2 = conv2.weight.data.shape
    weight2 = conv2.weight.data.view(groups, ws2[1] * ws2[2] * ws2[3] *
                                     gs)  #do not need transpose here
    bn_weight2 = batchnorm2.weight.data
    bn_bias2 = batchnorm2.bias.data
    bn_mean2 = batchnorm2.running_mean.data
    bn_var2 = batchnorm2.running_var.data

    ws3 = conv3.weight.data.shape
    weight3 = conv3.weight.data.squeeze().t().reshape(groups,
                                                      gs * conv3.out_channels)

    if kwargs.get('load_original_param'):  # need to pay special attention here
        # Lay the teacher weights out exactly like the student weights
        # above (one row per conv2 group).  The original code sized the
        # views from conv1.groups / conv2.out_channels and called .view()
        # on a transposed (non-contiguous) tensor, which raises or produces
        # rows that cannot be concatenated unless the sizes happen to match.
        teacher = kwargs['module_teacher']
        weight1_teach = teacher._modules['conv1'].weight.data.squeeze()\
            .view(groups, -1)
        weight3_teach = teacher._modules['conv3'].weight.data.squeeze()\
            .t().reshape(groups, -1)
        joint = torch.cat([weight1_teach, weight3_teach], dim=1)
    else:
        joint = torch.cat([weight1, weight3], dim=1)

    _, pindex = get_nonzero_index(joint,
                                  dim='input',
                                  counter=1,
                                  percentage=percentage,
                                  threshold=threshold)
    pl = len(pindex)

    def _keep_groups(per_channel):
        # View a per-channel vector as (groups, gs), keep the selected
        # groups and flatten it again.
        return torch.index_select(per_channel.view(groups, gs), dim=0,
                                  index=pindex).view(-1)

    conv1.weight = nn.Parameter(
        torch.index_select(weight1, dim=0,
                           index=pindex).view(pl * gs, ws1[1], 1, 1))
    conv1.out_channels = pl * gs

    batchnorm1.weight = nn.Parameter(_keep_groups(bn_weight1))
    batchnorm1.bias = nn.Parameter(_keep_groups(bn_bias1))
    batchnorm1.running_mean = _keep_groups(bn_mean1)
    batchnorm1.running_var = _keep_groups(bn_var1)
    batchnorm1.num_features = pl * gs

    conv2.weight = nn.Parameter(
        torch.index_select(weight2, dim=0,
                           index=pindex).view(pl * gs, gs, ws2[2], ws2[3]))
    conv2.out_channels, conv2.in_channels, conv2.groups = pl * gs, pl * gs, pl
    batchnorm2.weight = nn.Parameter(_keep_groups(bn_weight2))
    batchnorm2.bias = nn.Parameter(_keep_groups(bn_bias2))
    batchnorm2.running_mean = _keep_groups(bn_mean2)
    batchnorm2.running_var = _keep_groups(bn_var2)
    batchnorm2.num_features = pl * gs

    # NOTE(review): weight3 rows are (group, gs * out) built from the
    # transposed weight, and are viewed as (out, pl * gs, 1, 1) without
    # transposing back -- confirm the layout.
    conv3.weight = nn.Parameter(
        torch.index_select(weight3, dim=0,
                           index=pindex).view(ws3[0], pl * gs, 1, 1))
    conv3.in_channels = pl * gs