def compute_pindex(modules, percentage, threshold):
    """Collect channel indices and norms for decomposition and pruning.

    For each module, a specific convolution is inspected:
    - ``Dense``: ``body['3']``;
    - ``Transition``: child ``'3'``.

    Its weight is squeezed and transposed. The result is passed to
    ``get_nonzero_index`` twice: along ``'input'`` (decomposition) and along
    ``'output'`` (pruning).

    Args:
        modules: iterable of ``Dense`` / ``Transition`` modules.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.

    Returns:
        Tuple of four lists with one entry per module, in this order:
        ``(pindex_decom, pindex_prune, norm_decom, norm_prune)``.

    Raises:
        NotImplementedError: if a module is neither ``Dense`` nor
            ``Transition``.
    """
    pindex_decom_list, pindex_prune_list = [], []
    norm_decom_list, norm_prune_list = [], []
    for layer in modules:
        if isinstance(layer, Dense):
            conv = layer._modules['body']._modules['3']
        elif isinstance(layer, Transition):
            conv = layer._modules['3']
        else:
            raise NotImplementedError('Do not need to prune the layer ' + layer.__class__.__name__)

        matrix = conv.weight.squeeze().t()
        options = dict(counter=1, percentage=percentage, threshold=threshold)
        norm_d, idx_d = get_nonzero_index(matrix, dim='input', **options)
        norm_p, idx_p = get_nonzero_index(matrix, dim='output', **options)

        pindex_decom_list.append(idx_d)
        norm_decom_list.append(norm_d)
        pindex_prune_list.append(idx_p)
        norm_prune_list.append(norm_p)
    return pindex_decom_list, pindex_prune_list, norm_decom_list, norm_prune_list
def compress_module_param(module, percentage, threshold, index_pre=None, i=0):
    """Compress a ``(conv11, conv12)`` + batch-norm block in place.

    Block layout:
    - ``module[0]`` holds ``(conv11, conv12)``. conv12's weight is squeezed
      to a matrix, referred to here as the projection.
    - ``module[1]`` is a batch norm indexed like conv12's outputs.

    The projection is ranked twice with ``get_nonzero_index``:
    - ``'input'`` gives ``pindex1``: conv11 outputs / conv12 inputs to keep;
    - ``'output'`` gives ``pindex2``: conv12 outputs / batch-norm channels
      to keep.

    Args:
        module: indexable block laid out as described above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        index_pre: optional 1-D index tensor of conv11 input channels to keep
            (presumably the channels kept by the preceding block). ``None``
            keeps every input channel.
        i: block position. conv12's outputs and the batch norm are pruned
            only when ``i < 11``; the cutoff is hard-coded.
    """
    conv11 = module[0][0]
    conv12 = module[0][1]
    batchnorm1 = module[1]

    ws1 = conv11.weight.shape
    # conv11 as a (-1, out) matrix; for a 4-D weight that is (in * kh * kw, out).
    weight1 = conv11.weight.data.view(ws1[0], -1).t()
    # conv12 as an (in, out) matrix.
    projection1 = conv12.weight.data.squeeze().t()
    bias1 = conv12.bias.data if conv12.bias is not None else None
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    pindex1 = get_nonzero_index(projection1, dim='input', counter=1, percentage=percentage, threshold=threshold)[1]
    pindex2 = get_nonzero_index(projection1, dim='output', counter=1, percentage=percentage, threshold=threshold)[1]

    # conv11: optionally restrict its inputs to index_pre, then keep outputs pindex1.
    if index_pre is not None:
        kernel_size = ws1[2] * ws1[3]
        # Input channel c owns rows [c * kernel_size, (c + 1) * kernel_size) of weight1.
        # The offsets live on index_pre's device rather than a hard-coded CUDA device.
        offsets = torch.arange(kernel_size, device=index_pre.device).repeat(index_pre.shape[0])
        index = torch.repeat_interleave(index_pre, kernel_size) * kernel_size + offsets
        weight1 = torch.index_select(weight1, dim=0, index=index)
    weight1 = torch.index_select(weight1, dim=1, index=pindex1)
    conv11.weight = nn.Parameter(weight1.t().view(pindex1.shape[0], -1, ws1[2], ws1[3]))
    conv11.out_channels, conv11.in_channels = conv11.weight.size()[:2]

    # conv12: keep inputs pindex1; for i < 11 also keep only outputs pindex2.
    projection1 = torch.index_select(projection1, dim=0, index=pindex1)
    if i < 11:
        projection1 = torch.index_select(projection1, dim=1, index=pindex2)
        if bias1 is not None:
            conv12.bias = nn.Parameter(torch.index_select(bias1, dim=0, index=pindex2))
        # batchnorm1 follows conv12's output channels.
        batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex2))
        batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex2))
        batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex2)
        batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex2)
        batchnorm1.num_features = len(batchnorm1.weight)
    # TODO: check this one.
    conv12.weight = nn.Parameter(projection1.t().view(-1, pindex1.shape[0], 1, 1))
    conv12.out_channels, conv12.in_channels = conv12.weight.size()[:2]
def compress_module_param(module, percentage, threshold, p1_p2_same_ratio):
    """Shrink both decomposed (conv, 1x1 projection) pairs of a block in place.

    ``module.body`` holds one pair at ``'0'`` and another at ``'3'``. Each
    pair is a conv at ``'0'`` and a projection at ``'1'``. For each pair:
    - ``get_nonzero_index`` ranks the projection's input channels;
    - the conv's output channels and the projection's input channels are cut
      to the kept indices;
    - the projection's output channels are left untouched.

    Args:
        module: block whose ``body`` holds the two pairs.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: if true, the second selection receives
            ``fix_channel`` equal to the number of channels kept by the first.
    """
    body = module._modules['body']
    pair_a = body._modules['0']
    pair_b = body._modules['3']

    # Projections as (in, out) matrices.
    proj_a = pair_a._modules['1'].weight.data.squeeze().t()
    proj_b = pair_b._modules['1'].weight.data.squeeze().t()

    _, keep_a = get_nonzero_index(proj_a, dim='input', counter=1,
                                  percentage=percentage, threshold=threshold)
    fixed = len(keep_a) if p1_p2_same_ratio else 0
    _, keep_b = get_nonzero_index(proj_b, dim='input', counter=1,
                                  percentage=percentage, threshold=threshold,
                                  fix_channel=fixed)

    def _shrink(pair, projection, keep):
        # Cut the pair's conv outputs and projection inputs down to `keep`.
        conv, proj_conv = pair._modules['0'], pair._modules['1']
        shape = conv.weight.shape
        n_keep = keep.shape[0]
        flat = torch.index_select(conv.weight.data.view(shape[0], -1).t(), dim=1, index=keep)
        conv.weight = nn.Parameter(flat.t().view(n_keep, shape[1], shape[2], shape[3]))
        conv.out_channels, conv.in_channels = conv.weight.size()[:2]

        rows = torch.index_select(projection, dim=0, index=keep)
        proj_conv.weight = nn.Parameter(rows.t().view(shape[0], n_keep, 1, 1))
        proj_conv.out_channels, proj_conv.in_channels = proj_conv.weight.size()[:2]

    _shrink(pair_a, proj_a, keep_a)
    _shrink(pair_b, proj_b, keep_b)
def index_pre(self, percentage, threshold):
    """Return the kept output-channel indices of each block's projection.

    For every block from ``self.find_modules()``, the weight of
    ``block[0][1]`` is squeezed and transposed. It is then ranked with
    ``get_nonzero_index`` along ``'output'``.

    Returns:
        List of index tensors, one per block, in ``find_modules()`` order.
    """
    def _kept_outputs(block):
        projection = block[0][1].weight.data.squeeze().t()
        return get_nonzero_index(projection, dim='output', counter=1,
                                 percentage=percentage, threshold=threshold)[1]

    return list(map(_kept_outputs, self.find_modules()))
def get_compress_idx(module, percentage, threshold):
    """Rank one projection's channels two ways and summarize both selections.

    The weight of ``module[0][1]`` is squeezed and transposed. It is then
    passed to ``get_nonzero_index`` along ``'input'`` (decomposition) and
    along ``'output'`` (pruning).

    Returns:
        ``[decomposition_summary, pruning_summary]``. Each is an ``edict``
        with these keys:
        - ``stat_channel``: ``[total, removed, removed / total]``;
        - ``stat_remain_norm``: ``[max, mean, min]`` of the kept norms;
        - ``remain_norm``: the kept norms, detached and on CPU;
        - ``pindex``: the kept indices.
    """
    matrix = module[0][1].weight.data.squeeze().t()
    options = dict(counter=1, percentage=percentage, threshold=threshold)
    norm_in, idx_in = get_nonzero_index(matrix, dim='input', **options)    # decomposition
    norm_out, idx_out = get_nonzero_index(matrix, dim='output', **options)  # pruning

    def _summarize(norm, pindex):
        total = norm.shape[0]
        kept = norm[pindex]
        removed = total - kept.shape[0]
        kept = kept.detach().cpu()
        return edict({
            'stat_channel': [total, removed, removed / total],
            'stat_remain_norm': [kept.max(), kept.mean(), kept.min()],
            'remain_norm': kept,
            'pindex': pindex,
        })

    return [_summarize(norm_in, idx_in), _summarize(norm_out, idx_out)]
def get_compress_idx(module, percentage, threshold, p1_p2_same_ratio):
    """Select and summarize the inner channels of a bottleneck block.

    ``body['0']`` (conv1) and ``body['6']`` (conv3) are squeezed and
    transposed. For 1x1 convs this yields ``(in, out)`` matrices. conv1 is
    ranked along ``'output'`` and conv3 along ``'input'``.

    Args:
        module: block with a ``body`` container.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: if true, conv3's selection receives ``fix_channel``
            equal to the number of channels kept for conv1.

    Returns:
        ``[conv1_summary, conv3_summary]``. Each is an ``edict`` with these
        keys:
        - ``stat_channel``: ``[total, removed, removed / total]``;
        - ``stat_remain_norm``: ``[max, mean, min]`` of the kept norms;
        - ``remain_norm``: the kept norms, detached and on CPU;
        - ``pindex``: the kept indices.
    """
    body = module._modules['body']
    weight1 = body._modules['0'].weight.data.squeeze().t()
    weight3 = body._modules['6'].weight.data.squeeze().t()
    norm1, pindex1 = get_nonzero_index(weight1, dim='output', counter=1,
                                       percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    # 'input' replaces the former misspelling 'intput'. The misspelling only
    # selected input channels if get_nonzero_index treats any non-'output'
    # value as input.
    norm3, pindex3 = get_nonzero_index(weight3, dim='input', counter=1,
                                       percentage=percentage, threshold=threshold,
                                       fix_channel=fix_channel)

    def _get_compress_statistics(norm, pindex):
        remain_norm = norm[pindex]
        channels = norm.shape[0]
        remain_channels = remain_norm.shape[0]
        remain_norm = remain_norm.detach().cpu()
        stat_channel = [channels, channels - remain_channels, (channels - remain_channels) / channels]
        stat_remain_norm = [remain_norm.max(), remain_norm.mean(), remain_norm.min()]
        return edict({'stat_channel': stat_channel, 'stat_remain_norm': stat_remain_norm,
                      'remain_norm': remain_norm, 'pindex': pindex})

    return [_get_compress_statistics(norm1, pindex1), _get_compress_statistics(norm3, pindex3)]
def get_compress_idx(module, percentage, threshold, p1_p2_same_ratio):
    """Rank the input channels of both projections in a block and summarize.

    The projections are ``body['0']['1']`` and ``body['3']['1']``, each
    squeezed and transposed. Both are ranked with ``get_nonzero_index`` along
    ``'input'``. When ``p1_p2_same_ratio`` is true, the second ranking
    receives ``fix_channel`` equal to the number of channels kept by the
    first.

    Returns:
        Two ``edict`` summaries (first projection, then second). Each has
        these keys:
        - ``stat_channel``: ``[total, removed, removed / total]``;
        - ``stat_remain_norm``: ``[max, mean, min]`` of the kept norms;
        - ``remain_norm``: the kept norms, detached and on CPU;
        - ``pindex``: the kept indices.
    """
    body = module._modules['body']
    first_proj = body._modules['0']._modules['1'].weight.data.squeeze().t()
    second_proj = body._modules['3']._modules['1'].weight.data.squeeze().t()

    norm_a, keep_a = get_nonzero_index(first_proj, dim='input', counter=1,
                                       percentage=percentage, threshold=threshold)
    norm_b, keep_b = get_nonzero_index(second_proj, dim='input', counter=1,
                                       percentage=percentage, threshold=threshold,
                                       fix_channel=len(keep_a) if p1_p2_same_ratio else 0)

    results = []
    for norm, keep in ((norm_a, keep_a), (norm_b, keep_b)):
        survivors = norm[keep]
        n_all = norm.shape[0]
        n_cut = n_all - survivors.shape[0]
        survivors = survivors.detach().cpu()
        results.append(edict({
            'stat_channel': [n_all, n_cut, n_cut / n_all],
            'stat_remain_norm': [survivors.max(), survivors.mean(), survivors.min()],
            'remain_norm': survivors,
            'pindex': keep,
        }))
    return results
def index_pre(self, percentage, threshold):
    """Return the kept output-channel indices of each block's first layer.

    For every block from ``self.find_modules()``, the weight of ``block[0]``
    is viewed as ``(out_channels, -1)`` and transposed. It is then ranked
    with ``get_nonzero_index`` along ``'output'``.

    Returns:
        List of index tensors, one per block, in ``find_modules()`` order.
    """
    def _kept(block):
        conv = block[0]
        flat = conv.weight.data.view(conv.weight.shape[0], -1).t()
        return get_nonzero_index(flat, dim='output', percentage=percentage,
                                 threshold=threshold)[1]

    return [_kept(block) for block in self.find_modules()]
def compress_module_param(module, percentage, threshold, index_pre=None, i=0):
    """Compress a conv + batch-norm block in place.

    ``module[0]`` is the convolution and ``module[1]`` a batch norm indexed
    like the convolution's outputs. The conv weight is viewed as
    ``(in * kh * kw, out)``. Its output channels are ranked with
    ``get_nonzero_index`` along ``'output'``.

    Args:
        module: indexable block laid out as described above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        index_pre: optional 1-D index tensor of input channels to keep
            (presumably the channels kept by the preceding block). ``None``
            keeps every input channel.
        i: block position. Output channels, the bias and the batch norm are
            pruned only when ``i < 11``; the cutoff is hard-coded.
    """
    conv11 = module[0]
    batchnorm1 = module[1]

    ws1 = conv11.weight.shape
    weight1 = conv11.weight.data.view(ws1[0], -1).t()
    bias1 = conv11.bias.data if conv11.bias is not None else None
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data
    pindex1 = get_nonzero_index(weight1, dim='output', percentage=percentage, threshold=threshold)[1]

    # Input side: keep every kernel row of each surviving input channel.
    if index_pre is not None:
        kernel_size = ws1[2] * ws1[3]
        # Built on index_pre's device rather than a hard-coded CUDA device.
        offsets = torch.arange(kernel_size, device=index_pre.device).repeat(index_pre.shape[0])
        index = torch.repeat_interleave(index_pre, kernel_size) * kernel_size + offsets
        weight1 = torch.index_select(weight1, dim=0, index=index)

    if i < 11:
        weight1 = torch.index_select(weight1, dim=1, index=pindex1)
        # The conv may have no bias (bias1 is None); index only when present.
        if bias1 is not None:
            conv11.bias = nn.Parameter(torch.index_select(bias1, dim=0, index=pindex1))
        conv11.weight = nn.Parameter(weight1.t().view(pindex1.shape[0], -1, ws1[2], ws1[3]))
        batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex1))
        batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex1))
        batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex1)
        batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex1)
        batchnorm1.num_features = len(batchnorm1.weight)
    else:
        conv11.weight = nn.Parameter(weight1.t().view(ws1[0], -1, ws1[2], ws1[3]))
    conv11.out_channels, conv11.in_channels = conv11.weight.size()[:2]
def get_compress_idx(module, percentage, threshold):
    """Rank the groups of a grouped-conv bottleneck and summarize the choice.

    Both weights are reshaped to ``conv2.groups`` rows:
    - conv1's squeezed weight;
    - conv3's squeezed and transposed weight.

    The two are concatenated along columns, so each row gathers the conv1
    and conv3 weights belonging to one group. The rows are ranked with
    ``get_nonzero_index`` along ``'input'``.

    Returns:
        One-element list holding an ``edict`` with these keys:
        - ``stat_channel``: ``[total, removed, removed / total]``;
        - ``stat_remain_norm``: ``[max, mean, min]`` of the kept norms;
        - ``remain_norm``: the kept norms, detached and on CPU;
        - ``pindex``: the kept indices.
    """
    conv1 = module._modules['conv1']
    conv2 = module._modules['conv2']
    conv3 = module._modules['conv3']
    groups = conv2.groups
    # (A leftover debug print of conv1's weight shape used to be here.)
    weight1 = conv1.weight.data.squeeze().view(groups, -1)
    weight3 = conv3.weight.data.squeeze().t().reshape(groups, -1)
    joint = torch.cat([weight1, weight3], dim=1)
    norm, pindex = get_nonzero_index(joint, dim='input', counter=1,
                                     percentage=percentage, threshold=threshold)

    def _get_compress_statistics(norm, pindex):
        remain_norm = norm[pindex]
        channels = norm.shape[0]
        remain_channels = remain_norm.shape[0]
        remain_norm = remain_norm.detach().cpu()
        stat_channel = [channels, channels - remain_channels, (channels - remain_channels) / channels]
        stat_remain_norm = [remain_norm.max(), remain_norm.mean(), remain_norm.min()]
        return edict({
            'stat_channel': stat_channel,
            'stat_remain_norm': stat_remain_norm,
            'remain_norm': remain_norm,
            'pindex': pindex,
        })

    return [_get_compress_statistics(norm, pindex)]
def compress_module_param(percentage, threshold, p1_p2_same_ratio, **kwargs):
    """Prune the inner channels of a bottleneck block in place.

    ``module.body`` is read as:
    - conv1 at ``'0'``, batchnorm1 at ``'1'``;
    - conv2 at ``'3'``, batchnorm2 at ``'4'``;
    - conv3 at ``'6'``.

    Two index sets drive the pruning:
    - ``pindex1`` cuts conv1's outputs, batchnorm1 and conv2's inputs;
    - ``pindex3`` cuts conv2's outputs, batchnorm2 and conv3's inputs.

    Args:
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: if true, the second selection receives
            ``fix_channel`` equal to ``len(pindex1)``.
        **kwargs: ``module`` (required) is the block to compress. When
            ``load_original_param`` is truthy, the indices are computed from
            ``module_teacher`` (same layout). The kept weights are still
            taken from ``module``.
    """
    module = kwargs['module']
    body = module._modules['body']
    conv1 = body._modules['0']
    batchnorm1 = body._modules['1']
    conv2 = body._modules['3']
    batchnorm2 = body._modules['4']
    conv3 = body._modules['6']

    # conv1 as an (in, out) matrix (1x1 conv squeezed, then transposed).
    weight1 = conv1.weight.data.squeeze().t()
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    # conv2 as an (in * kh * kw, out) matrix; input channel c owns kh * kw consecutive rows.
    ws2 = conv2.weight.data.shape
    weight2 = conv2.weight.data.view(ws2[0], ws2[1] * ws2[2] * ws2[3]).t()
    bn_weight2 = batchnorm2.weight.data
    bn_bias2 = batchnorm2.bias.data
    bn_mean2 = batchnorm2.running_mean.data
    bn_var2 = batchnorm2.running_var.data

    # conv3 as an (in, out) matrix.
    weight3 = conv3.weight.data.squeeze().t()

    if kwargs.get('load_original_param'):
        teacher_body = kwargs['module_teacher']._modules['body']
        weight1_teach = teacher_body._modules['0'].weight.data.squeeze().t()
        # Transposed like weight3 so that rows are conv3's input channels;
        # otherwise the 'input' selection would range over conv3's outputs.
        weight3_teach = teacher_body._modules['6'].weight.data.squeeze().t()
    else:
        weight1_teach = weight1
        weight3_teach = weight3

    _, pindex1 = get_nonzero_index(weight1_teach, dim='output', counter=1,
                                   percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    # 'input' (formerly misspelled 'intput'): select conv3's input channels.
    _, pindex3 = get_nonzero_index(weight3_teach, dim='input', counter=1,
                                   percentage=percentage, threshold=threshold,
                                   fix_channel=fix_channel)
    pl1, pl3 = len(pindex1), len(pindex3)

    # conv1: keep output channels pindex1.
    conv1.weight = nn.Parameter(torch.index_select(weight1, dim=1, index=pindex1).t().view(pl1, -1, 1, 1))
    conv1.out_channels = pl1

    # batchnorm1 follows conv1's outputs.
    batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex1))
    batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex1))
    batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex1)
    batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex1)
    batchnorm1.num_features = pl1

    # conv2: keep input channels pindex1 (all kernel rows of each) and outputs pindex3.
    kernel_size = ws2[2] * ws2[3]
    offsets = torch.arange(kernel_size, device=pindex1.device).repeat(pl1)
    index = torch.repeat_interleave(pindex1, kernel_size) * kernel_size + offsets
    weight2 = torch.index_select(weight2, dim=0, index=index)
    weight2 = torch.index_select(weight2, dim=1, index=pindex3)
    # Transpose back to (out, in * kh * kw) before reshaping. Viewing the
    # (in * kh * kw, out) buffer directly as (out, in, kh, kw) scrambles the kernel.
    conv2.weight = nn.Parameter(weight2.t().reshape(pl3, pl1, ws2[2], ws2[3]))
    conv2.out_channels, conv2.in_channels = pl3, pl1

    # batchnorm2 follows conv2's outputs.
    batchnorm2.weight = nn.Parameter(torch.index_select(bn_weight2, dim=0, index=pindex3))
    batchnorm2.bias = nn.Parameter(torch.index_select(bn_bias2, dim=0, index=pindex3))
    batchnorm2.running_mean = torch.index_select(bn_mean2, dim=0, index=pindex3)
    batchnorm2.running_var = torch.index_select(bn_var2, dim=0, index=pindex3)
    batchnorm2.num_features = pl3

    # conv3: keep input rows pindex3, then transpose (in, out) back to (out, in).
    conv3.weight = nn.Parameter(torch.index_select(weight3, dim=0, index=pindex3).t().reshape(-1, pl3, 1, 1))
    conv3.in_channels = pl3
def compress_module_param(module, percentage, threshold, p1_p2_same_ratio):
    """Prune the channels between two decomposed convolutions in place.

    Module layout:
    - ``module.conv1`` holds ``(conv11, conv12)``;
    - ``module.conv2`` holds ``(conv21, conv22)``;
    - ``module.bn2`` is indexed with the same selection as conv12's outputs.

    Two index sets drive the pruning:
    - ``pindex1`` cuts conv12's outputs (and its bias, if any), ``bn2`` and
      conv21's inputs;
    - ``pindex2`` cuts conv21's outputs and conv22's inputs.

    conv11 is left unchanged.

    Args:
        module: block laid out as described above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: if true, the second selection receives
            ``fix_channel`` equal to ``len(pindex1)``.
    """
    conv11 = module._modules['conv1']._modules['0']
    conv12 = module._modules['conv1']._modules['1']
    conv21 = module._modules['conv2']._modules['0']
    conv22 = module._modules['conv2']._modules['1']
    batchnorm2 = module._modules['bn2']

    ws1 = conv11.weight.shape
    # conv12 as an (in, out) matrix.
    projection1 = conv12.weight.data.squeeze().t()
    # The projection may have no bias; the unguarded `.bias.data` crashed then.
    bias1 = conv12.bias.data if conv12.bias is not None else None
    bn_weight2 = batchnorm2.weight.data
    bn_bias2 = batchnorm2.bias.data
    bn_mean2 = batchnorm2.running_mean.data
    bn_var2 = batchnorm2.running_var.data
    # conv21 as an (in * kh * kw, out) matrix.
    ws2 = conv21.weight.shape
    weight2 = conv21.weight.data.view(ws2[0], -1).t()
    # conv22 as an (in, out) matrix.
    projection2 = conv22.weight.data.squeeze().t()

    _, pindex1 = get_nonzero_index(projection1, dim='output', counter=1,
                                   percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    _, pindex2 = get_nonzero_index(projection2, dim='input', counter=1,
                                   percentage=percentage, threshold=threshold,
                                   fix_channel=fix_channel)

    # conv11 does not need to be changed.

    # conv12: keep output channels pindex1.
    projection1 = torch.index_select(projection1, dim=1, index=pindex1)
    conv12.weight = nn.Parameter(projection1.t().view(pindex1.shape[0], ws1[0], 1, 1))  # TODO: check this one.
    if bias1 is not None:
        conv12.bias = nn.Parameter(torch.index_select(bias1, dim=0, index=pindex1))
    conv12.out_channels = conv12.weight.size()[0]

    # batchnorm2 follows conv12's outputs.
    batchnorm2.weight = nn.Parameter(torch.index_select(bn_weight2, dim=0, index=pindex1))
    batchnorm2.bias = nn.Parameter(torch.index_select(bn_bias2, dim=0, index=pindex1))
    batchnorm2.running_mean = torch.index_select(bn_mean2, dim=0, index=pindex1)
    batchnorm2.running_var = torch.index_select(bn_var2, dim=0, index=pindex1)
    batchnorm2.num_features = len(batchnorm2.weight)

    # conv21: keep input channels pindex1 (all kernel rows of each) and outputs pindex2.
    kernel_size = ws2[2] * ws2[3]
    # Offsets on pindex1's device instead of a hard-coded CUDA device.
    offsets = torch.arange(kernel_size, device=pindex1.device).repeat(pindex1.shape[0])
    index = torch.repeat_interleave(pindex1, kernel_size) * kernel_size + offsets
    weight2 = torch.index_select(weight2, dim=0, index=index)
    weight2 = torch.index_select(weight2, dim=1, index=pindex2)
    conv21.weight = nn.Parameter(weight2.t().view(pindex2.shape[0], pindex1.shape[0], ws2[2], ws2[3]))
    conv21.out_channels, conv21.in_channels = conv21.weight.size()[:2]

    # conv22: keep input channels pindex2.
    projection2 = torch.index_select(projection2, dim=0, index=pindex2)
    conv22.weight = nn.Parameter(projection2.t().view(-1, pindex2.shape[0], 1, 1))
    conv22.in_channels = conv22.weight.size()[1]
def compress_module_param(module, percentage, threshold, p1_p2_same_ratio):
    """Prune the channels between two decomposed convolutions in place.

    ``module.body`` is read as:
    - ``(conv11, conv12)`` at ``'0'``;
    - batchnorm1 at ``'1'``, indexed like conv12's outputs;
    - ``(conv21, conv22)`` at ``'3'``.

    Two index sets drive the pruning:
    - ``pindex1`` cuts conv12's outputs (and its bias, if any), batchnorm1
      and conv21's inputs;
    - ``pindex2`` cuts conv21's outputs and conv22's inputs.

    conv11 is left unchanged.

    Args:
        module: block with a ``body`` laid out as described above.
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        p1_p2_same_ratio: if true, the second selection receives
            ``fix_channel`` equal to ``len(pindex1)``.
    """
    body = module._modules['body']
    conv11 = body._modules['0']._modules['0']
    conv12 = body._modules['0']._modules['1']
    batchnorm1 = body._modules['1']
    conv21 = body._modules['3']._modules['0']
    conv22 = body._modules['3']._modules['1']

    ws1 = conv11.weight.shape
    # conv12 as an (in, out) matrix.
    projection1 = conv12.weight.data.squeeze().t()
    bias1 = conv12.bias.data if conv12.bias is not None else None
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data
    # conv21 as an (in * kh * kw, out) matrix.
    ws2 = conv21.weight.shape
    weight2 = conv21.weight.data.view(ws2[0], -1).t()
    # conv22 as an (in, out) matrix.
    projection2 = conv22.weight.data.squeeze().t()

    _, pindex1 = get_nonzero_index(projection1, dim='output', counter=1,
                                   percentage=percentage, threshold=threshold)
    fix_channel = len(pindex1) if p1_p2_same_ratio else 0
    _, pindex2 = get_nonzero_index(projection2, dim='input', counter=1,
                                   percentage=percentage, threshold=threshold,
                                   fix_channel=fix_channel)

    # conv11 does not need to be changed.

    # conv12: keep output channels pindex1.
    projection1 = torch.index_select(projection1, dim=1, index=pindex1)
    conv12.weight = nn.Parameter(projection1.t().view(pindex1.shape[0], ws1[0], 1, 1))  # TODO: check this one.
    if bias1 is not None:
        conv12.bias = nn.Parameter(torch.index_select(bias1, dim=0, index=pindex1))
    conv12.out_channels = conv12.weight.size()[0]

    # batchnorm1 follows conv12's outputs.
    batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex1))
    batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex1))
    batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex1)
    batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex1)
    batchnorm1.num_features = len(batchnorm1.weight)

    # conv21: keep input channels pindex1 (all kernel rows of each) and outputs pindex2.
    kernel_size = ws2[2] * ws2[3]
    # Offsets on pindex1's device instead of a hard-coded CUDA device.
    offsets = torch.arange(kernel_size, device=pindex1.device).repeat(pindex1.shape[0])
    index = torch.repeat_interleave(pindex1, kernel_size) * kernel_size + offsets
    weight2 = torch.index_select(weight2, dim=0, index=index)
    weight2 = torch.index_select(weight2, dim=1, index=pindex2)
    conv21.weight = nn.Parameter(weight2.t().view(pindex2.shape[0], pindex1.shape[0], ws2[2], ws2[3]))
    conv21.out_channels, conv21.in_channels = conv21.weight.size()[:2]

    # conv22: keep input channels pindex2.
    projection2 = torch.index_select(projection2, dim=0, index=pindex2)
    conv22.weight = nn.Parameter(projection2.t().view(-1, pindex2.shape[0], 1, 1))
    conv22.in_channels = conv22.weight.size()[1]
def compress_module_param(percentage, threshold, **kwargs):
    """Prune whole groups of a grouped-conv bottleneck in place.

    ``kwargs['module']`` provides ``conv1``, ``bn1``, ``conv2`` (grouped),
    ``bn2`` and ``conv3``. Let ``groups = conv2.groups`` and
    ``gs = conv2.in_channels // groups``.

    Selection:
    - conv1's weight and conv3's transposed weight are reshaped to
      ``groups`` rows and concatenated;
    - ``get_nonzero_index`` selects the rows (groups) to keep, ``pindex``.

    Every layer is then reduced to the ``len(pindex) * gs`` channels of the
    kept groups, and conv2's ``groups`` becomes ``len(pindex)``.

    Args:
        percentage: forwarded to ``get_nonzero_index``.
        threshold: forwarded to ``get_nonzero_index``.
        **kwargs: ``module`` (required) is the block to compress. When
            ``load_original_param`` is truthy, the groups are ranked on
            ``module_teacher``'s conv1/conv3 (same layout). The kept weights
            are still taken from ``module``.
    """
    module = kwargs['module']
    conv1 = module._modules['conv1']
    batchnorm1 = module._modules['bn1']
    conv2 = module._modules['conv2']
    batchnorm2 = module._modules['bn2']
    conv3 = module._modules['conv3']

    groups = conv2.groups
    gs = conv2.in_channels // groups

    # conv1: one row per group, holding that group's output channels x conv1 inputs.
    ws1 = conv1.weight.data.shape
    weight1 = conv1.weight.data.squeeze().view(groups, gs * conv1.in_channels)
    bn_weight1 = batchnorm1.weight.data
    bn_bias1 = batchnorm1.bias.data
    bn_mean1 = batchnorm1.running_mean.data
    bn_var1 = batchnorm1.running_var.data

    # conv2: one row per group, holding that group's filters (no transpose needed).
    ws2 = conv2.weight.data.shape
    weight2 = conv2.weight.data.view(groups, ws2[1] * ws2[2] * ws2[3] * gs)
    bn_weight2 = batchnorm2.weight.data
    bn_bias2 = batchnorm2.bias.data
    bn_mean2 = batchnorm2.running_mean.data
    bn_var2 = batchnorm2.running_var.data

    # conv3 as (in, out), one row per group of its input channels.
    ws3 = conv3.weight.data.shape
    weight3 = conv3.weight.data.squeeze().t().reshape(groups, gs * conv3.out_channels)

    if kwargs.get('load_original_param'):
        teacher = kwargs['module_teacher']
        # Reshape the teacher exactly like the student rows above. Using
        # conv1.groups, or `.t().view` on a non-contiguous tensor, cannot
        # produce this layout.
        weight1_teach = teacher._modules['conv1'].weight.data.squeeze() \
            .view(groups, gs * conv1.in_channels)
        weight3_teach = teacher._modules['conv3'].weight.data.squeeze().t() \
            .reshape(groups, gs * conv3.out_channels)
        joint = torch.cat([weight1_teach, weight3_teach], dim=1)
    else:
        joint = torch.cat([weight1, weight3], dim=1)
    _, pindex = get_nonzero_index(joint, dim='input', counter=1,
                                  percentage=percentage, threshold=threshold)
    pl = len(pindex)

    def _select_groups(vec):
        # Keep the per-channel entries of the kept groups.
        return torch.index_select(vec.view(groups, gs), dim=0, index=pindex).view(-1)

    # conv1 + batchnorm1.
    conv1.weight = nn.Parameter(
        torch.index_select(weight1, dim=0, index=pindex).view(pl * gs, ws1[1], 1, 1))
    conv1.out_channels = pl * gs
    batchnorm1.weight = nn.Parameter(_select_groups(bn_weight1))
    batchnorm1.bias = nn.Parameter(_select_groups(bn_bias1))
    batchnorm1.running_mean = _select_groups(bn_mean1)
    batchnorm1.running_var = _select_groups(bn_var1)
    batchnorm1.num_features = pl * gs

    # conv2 + batchnorm2: whole groups are dropped, so the group count shrinks to pl.
    conv2.weight = nn.Parameter(
        torch.index_select(weight2, dim=0, index=pindex).view(pl * gs, gs, ws2[2], ws2[3]))
    conv2.out_channels, conv2.in_channels, conv2.groups = pl * gs, pl * gs, pl
    batchnorm2.weight = nn.Parameter(_select_groups(bn_weight2))
    batchnorm2.bias = nn.Parameter(_select_groups(bn_bias2))
    batchnorm2.running_mean = _select_groups(bn_mean2)
    batchnorm2.running_var = _select_groups(bn_var2)
    batchnorm2.num_features = pl * gs

    # conv3: the kept rows form an (in, out) = (pl * gs, out) matrix. Transpose
    # to (out, in) before adding the 1x1 dims; viewing it directly as
    # (out, in) scrambles the weights.
    kept3 = torch.index_select(weight3, dim=0, index=pindex).view(pl * gs, ws3[0])
    conv3.weight = nn.Parameter(kept3.t().reshape(ws3[0], pl * gs, 1, 1))
    conv3.in_channels = pl * gs