def get_prune_list(resnet_channel_list, prob_list, dcfg, expand_rate=0.0001):
    """Turn searched channel-gate probabilities into a concrete pruned channel list.

    For every searchable conv the expected width ``dot(prob, out_plane_list)`` is
    capped at the full width, rounded, and then widened by
    ``ceil(expand_rate * (full_width - pruned_width))``.

    Args:
        resnet_channel_list: full-width spec ``[stem, stage_1, stage_2, ...]`` where
            ``stem`` is an int and each stage is a list of blocks, each block a
            list of output widths.
        prob_list: per-conv probability tensors, consumed in the same order the
            widths appear in ``resnet_channel_list`` (stem first, then every
            width of every block of every stage).
        dcfg: ``DcpConfig`` used to build a throw-away ``DNAS.Conv2d`` per layer
            so that its ``out_plane_list`` (candidate widths) can be read.
        expand_rate: fraction of the pruned-away channels to add back. With the
            default 1e-4, any layer narrower than 10000 channels that lost at
            least one channel regains exactly one (ceil of a value in (0, 1]).

    Returns:
        A list with the same nesting as ``resnet_channel_list`` holding the
        pruned widths.
    """
    import numpy as np

    def _pruned_width(prob, chn_input, chn_output):
        # Probe conv: built only to read its candidate output widths.
        conv = DNAS.Conv2d(chn_input, chn_output, 1, 1, 1, False, dcfg=dcfg)
        expected = torch.dot(prob, conv.out_plane_list).item()
        width = int(np.round(min(expected, chn_output)))
        width += int(np.ceil(expand_rate * (chn_output - width)))
        return width

    # The stem conv consumes the 3-channel image.
    chn_input_full = 3
    chn_output_full = resnet_channel_list[0]
    prune_list = [_pruned_width(prob_list[0], chn_input_full, chn_output_full)]
    chn_input_full = chn_output_full
    idx = 1

    for blocks in resnet_channel_list[1:]:
        blocks_list = []
        for block in blocks:
            block_prune_list = []
            for chn_output_full in block:
                block_prune_list.append(
                    _pruned_width(prob_list[idx], chn_input_full, chn_output_full))
                # Feed this full width forward as the next probe conv's input width.
                chn_input_full = chn_output_full
                idx += 1
            blocks_list.append(block_prune_list)
        prune_list.append(blocks_list)
    return prune_list
def forward(self, x, tau=1, noise=False, reuse_prob=None, rmask=None):
    """Run the gated bottleneck block (BN-ReLU applied before each conv).

    Args:
        x: block input.
        tau, noise: forwarded unchanged to every gated conv.
        reuse_prob: probability of the incoming channel gate; used as the
            shortcut's ``p_in`` and as conv3's ``reuse_prob``.
        rmask: channel mask of the incoming features.

    Returns:
        ``(out, rmask_3, prob, prob_list, flops_list)``. ``prob`` is conv3's
        probability; ``prob_list``/``flops_list`` are ordered
        ``[conv1, conv2, conv3]`` followed by the projection shortcut's entry
        when the block has one.
    """
    identity = x
    prob = reuse_prob
    x = F.relu(DNAS.weighted_feature(self.bn0(x), rmask))

    # The projection shortcut runs first; it replaces both prob and rmask.
    shortcut_entry = None
    if self.shortcut is not None:
        identity, rmask, prob, sc_flops = self.shortcut(x, tau, noise, p_in=prob)
        shortcut_entry = (prob, sc_flops)

    # NOTE(review): with a shortcut, `prob` is now the shortcut's output
    # probability yet is passed as conv1's p_in, although conv1 reads the same
    # input as the shortcut. Confirm this is intended.
    x, rmask_1, p1, flops_1 = self.conv1(x, tau, noise, p_in=prob)
    x = F.relu(DNAS.weighted_feature(self.bn1(x), rmask_1))

    x, rmask_2, p2, flops_2 = self.conv2(x, tau, noise, p_in=p1)
    x = F.relu(DNAS.weighted_feature(self.bn2(x), rmask_2))

    x, rmask_3, prob, flops_3 = self.conv3(x, tau, noise, reuse_prob=prob, p_in=p2)

    prob_list = [p1, p2, prob]
    flops_list = [flops_1, flops_2, flops_3]
    if shortcut_entry is not None:
        prob_list.append(shortcut_entry[0])
        flops_list.append(shortcut_entry[1])

    x += identity
    x = DNAS.weighted_feature(x, rmask)
    return x, rmask_3, prob, prob_list, flops_list
def __init__(self, in_planes, out_planes_list, stride=2, project=False, dcfg=None):
    """Build a gated bottleneck block: 1x1 conv -> 3x3 conv (strided) -> 1x1 conv.

    Args:
        in_planes: number of input channels.
        out_planes_list: output widths; entries 0/1/2 size conv1/conv2/conv3 and
            the last entry sizes the projection shortcut.
        stride: stride of the 3x3 conv and of the projection shortcut.
        project: when True, add a strided 1x1 projection shortcut and set its
            gate as ``dcfg.reuse_gate`` before conv3 is built.
        dcfg: shared ``DcpConfig``; required (asserted non-None).
    """
    super(BottleneckGated, self).__init__()
    width_1, width_2, width_3 = out_planes_list[0], out_planes_list[1], out_planes_list[2]
    assert dcfg is not None
    self.dcfg = dcfg
    # Separate config for conv1/conv2/shortcut, presumably so they do not
    # pick up the shared reuse_gate (name suggests "non-reuse") -- confirm.
    self.dcfg_nonreuse = dcfg.copy()

    def make_bn(channels):
        return nn.BatchNorm2d(channels, momentum=_BATCH_NORM_DECAY, eps=_EPSILON)

    self.bn0 = make_bn(in_planes)
    self.conv1 = DNAS.Conv2d(in_planes, width_1, kernel_size=1, stride=1, padding=0,
                             bias=False, dcfg=self.dcfg_nonreuse)
    self.bn1 = make_bn(width_1)
    self.conv2 = DNAS.Conv2d(width_1, width_2, kernel_size=3, stride=stride, padding=1,
                             bias=False, dcfg=self.dcfg_nonreuse)
    self.bn2 = make_bn(width_2)
    self.shortcut = None
    if project:
        self.shortcut = DNAS.Conv2d(in_planes, out_planes_list[-1], kernel_size=1,
                                    stride=stride, padding=0, bias=False,
                                    dcfg=self.dcfg_nonreuse)
        # Mutates the shared config: conv3 below (and anything else built with
        # this dcfg afterwards) sees the shortcut's gate as reuse_gate.
        self.dcfg.reuse_gate = self.shortcut.gate
    self.conv3 = DNAS.Conv2d(width_2, width_3, kernel_size=1, stride=1, padding=0,
                             bias=False, dcfg=self.dcfg)
def __init__(self, num_blocks, num_planes, num_colors=3, scale=1, res_scale=0.1):
    """Build a gated (channel-searchable) EDSR network.

    Args:
        num_blocks: number of ``EDSRBlockGated`` residual blocks.
        num_planes: feature width used by the head conv and every block.
        num_colors: number of image channels (input and output).
        scale: upsampling factor passed to ``Upsampler``.
        res_scale: residual scaling factor passed to each block.
    """
    super(EDSRGated, self).__init__()
    self.num_blocks = num_blocks
    self.act = nn.ReLU(inplace=True)
    self.sub_mean = MeanShift(1)
    self.add_mean = MeanShift(1, sign=1)
    # NOTE(review): self.dcfg is read here but never assigned in this method;
    # presumably a class-level attribute of EDSRGated -- confirm.
    self.conv0 = dnas.Conv2d(num_colors, num_planes, 3, stride=1, padding=1, bias=False,
                             dcfg=self.dcfg)
    # Blocks built with the shared dcfg below see the head conv's gate as reuse_gate.
    self.dcfg.reuse_gate = self.conv0.gate
    # nn.ModuleList (not a plain list) so the blocks' parameters are registered:
    # returned by .parameters(), moved by .to(device), and saved in state_dict.
    # Checkpoints written before this fix lack the 'blocks.*' entries.
    self.blocks = nn.ModuleList(
        [EDSRBlockGated(num_planes, res_scale, self.dcfg) for _ in range(num_blocks)])
    self.tail = nn.Sequential(
        Upsampler(scale, num_planes),
        nn.Conv2d(in_channels=num_planes, out_channels=num_colors, kernel_size=3,
                  stride=1, padding=1, bias=False),
    )
def __init__(self, n_layer, n_class, channel_lists, dcfg):
    """Build a gated ResNet whose convs are DNAS-searchable.

    Args:
        n_layer: depth; must be a key of the module-level ``cfg`` table, whose
            value gives the number of cells per stage.
        n_class: number of output classes of the final ``DNAS.Linear``.
        channel_lists: nested width spec; ``channel_lists[0]`` is the stem width
            and ``channel_lists[-1][-1][-1]`` the final feature width.
        dcfg: shared ``DcpConfig``; its ``reuse_gate`` is set to the stem's gate.

    Raises:
        ValueError: if ``n_layer`` is not a key of ``cfg``.
    """
    super(ResNet, self).__init__()
    if n_layer not in cfg:
        # Raise instead of print + exit(1): a constructor must not terminate the
        # interpreter, and exit() is a site-module helper meant for the REPL.
        raise ValueError('Unsupported number of layers: {} (expected one of {})'.format(
            n_layer, list(cfg.keys())))
    self.n_class = n_class
    self.channel_lists = channel_lists
    self.block_n_cell = cfg[n_layer]
    self.dcfg = dcfg
    self.base_n_channel = channel_lists[0]
    # More than three stages selects the ImageNet stem (7x7/2 conv + 3x3/2
    # max-pool); otherwise a 3x3 stride-1 stem is used.
    self.imagenet = len(self.block_n_cell) > 3
    if self.imagenet:
        self.cell_fn = BottleneckGated if n_layer >= 50 else ResidualBlockGated
        self.conv0 = DNAS.Conv2d(3, self.base_n_channel, 7, stride=2, padding=3,
                                 bias=False, dcfg=self.dcfg)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    else:
        self.cell_fn = ResidualBlockGated
        self.conv0 = DNAS.Conv2d(3, self.base_n_channel, 3, stride=1, padding=1,
                                 bias=False, dcfg=self.dcfg)
    # Layers built afterwards with this shared dcfg see the stem gate as reuse_gate.
    self.dcfg.reuse_gate = self.conv0.gate
    self.block_list = self._block_layers()
    final_width = channel_lists[-1][-1][-1]
    self.bn = nn.BatchNorm2d(final_width, momentum=_BATCH_NORM_DECAY, eps=_EPSILON)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = DNAS.Linear(final_width, self.n_class, dcfg=self.dcfg)
    self.apply(_weights_init)
def train_prune(self, tau, n_epoch=250, load_path='./models/search/model.pth',
                save_path='./models/prune/model.pth'):
    """Load a searched model and run one gated forward pass on a training batch.

    ``load_path``/``save_path`` previously had no defaults after the defaulted
    ``n_epoch``, which is a SyntaxError; they now default to the same paths as
    the classification learner's ``train_prune``.

    Args:
        tau: temperature forwarded to the gated forward pass.
        n_epoch: number of epochs for the (not yet implemented) retraining step.
        load_path: checkpoint of the searched model.
        save_path: destination of the pruned model.

    NOTE(review): ``dcfg``, ``save_path``, ``n_epoch`` and the forward-pass
    results are not used yet -- the prune/retrain steps appear to be missing.
    """
    self.load_model(load_path)
    # DNAS.TYPE_A (was dnas.TYPE_A): this call already requires DNAS, and the
    # other DcpConfig constructions use DNAS.TYPE_A.
    dcfg = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)
    data = next(iter(self.train_loader))
    # Batch is expected to be exactly (input, ground_truth).
    in_data, im_gt = [x.to(self.device) for x in data]
    self.net.eval()
    outputs, prob_list, flops, flops_list = self.forward(in_data, tau=tau, noise=False)
def __init__(self, num_planes, res_scale=1, dcfg=None):
    """Build a gated EDSR residual block with two 3x3 gated convs and a ReLU.

    Args:
        num_planes: input and output width of both convs.
        res_scale: residual scaling factor, stored for use in forward (not shown).
        dcfg: shared ``DcpConfig``; required (asserted non-None). conv2 is built
            with it directly, conv1 with a private copy.
    """
    super(EDSRBlockGated, self).__init__()
    assert dcfg is not None
    self.dcfg = dcfg
    self.dcfg_nonreuse = dcfg.copy()
    same_size_3x3 = dict(kernel_size=3, stride=1, padding=1, bias=False)
    self.conv1 = dnas.Conv2d(num_planes, num_planes, dcfg=self.dcfg_nonreuse, **same_size_3x3)
    self.act = nn.ReLU(inplace=True)
    self.conv2 = dnas.Conv2d(num_planes, num_planes, dcfg=self.dcfg, **same_size_3x3)
    self.res_scale = res_scale
def forward(self, x, tau=1, noise=False, reuse_prob=None, rmask=None):
    """Run the gated residual block (BN-ReLU applied before the convs).

    Args:
        x: block input.
        tau, noise: forwarded unchanged to every gated conv.
        reuse_prob: probability of the incoming channel gate; used as the
            shortcut's ``p_in`` and as conv2's ``reuse_prob``.
        rmask: channel mask of the incoming features.

    Returns:
        ``(out, rmask_2, prob, prob_list, flops_list)``. ``prob`` is conv2's
        probability; ``prob_list``/``flops_list`` are ordered ``[conv1, conv2]``
        followed by the projection shortcut's entry when the block has one.
    """
    identity = x
    prob = reuse_prob
    x = F.relu(DNAS.weighted_feature(self.bn0(x), rmask))

    # The projection shortcut runs first; it replaces both prob and rmask.
    shortcut_entry = None
    if self.shortcut is not None:
        identity, rmask, prob, sc_flops = self.shortcut(x, tau, noise, p_in=prob)
        shortcut_entry = (prob, sc_flops)

    # NOTE(review): with a shortcut, conv1's p_in is the shortcut's output
    # probability (an earlier version of this method did the same) although
    # conv1 reads the same input as the shortcut. Confirm this is intended.
    x, rmask_1, p1, flops_1 = self.conv1(x, tau, noise, p_in=prob)
    x = F.relu(DNAS.weighted_feature(self.bn1(x), rmask_1))

    x, rmask_2, prob, flops_2 = self.conv2(x, tau, noise, reuse_prob=prob, p_in=p1)

    # Order of prob/flops entries must match the order of widths in the
    # block's channel list (conv1, conv2, then shortcut).
    prob_list = [p1, prob]
    flops_list = [flops_1, flops_2]
    if shortcut_entry is not None:
        prob_list.append(shortcut_entry[0])
        flops_list.append(shortcut_entry[1])

    x += identity
    x = DNAS.weighted_feature(x, rmask)
    return x, rmask_2, prob, prob_list, flops_list
def train_prune(self, tau, n_epoch=250, load_path='./models/search/model.pth',
                save_path='./models/prune/model.pth'):
    """Derive a pruned ResNet from a searched model and train it.

    Steps:
      1. Load the searched (gated) model from ``load_path``.
      2. Run one forward pass on a training batch (eval mode, no noise) to get
         the per-conv gate probabilities, printing them for inspection.
      3. Convert the probabilities into a pruned channel list.
      4. Build ``ResNetL`` with that list, train it with ``FullLearner`` for
         ``n_epoch`` epochs and save it to ``save_path``.
    """
    self.load_model(load_path)
    dcfg = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)
    full_channel_list = ResNetChannelList(self.args.teacher_net_index)

    self.net.eval()
    batch = next(iter(self.train_loader))
    images = batch[0].to(self.device)
    _targets = batch[1].to(self.device)
    _, probs, _, _ = self.forward(images, tau=tau, noise=False)

    print('=================')
    print(tau)
    for p in probs:
        print(''.join('{0:.2f} '.format(v) for v in p.tolist()))
    print('------------------')

    # Debug dump of a few stem-conv weights.
    # NOTE(review): self.forward is used as a module here (named_parameters);
    # presumably a wrapped network rather than a bound method -- confirm.
    for param_name, param in self.forward.named_parameters():
        if 'conv0.conv.weight' not in param_name:
            continue
        print(param[:, 0, 0, 0])
        print(param[:, 1, 2, 1])

    pruned_channel_list = get_prune_list(full_channel_list, probs, dcfg=dcfg)
    print(pruned_channel_list)

    pruned_net = ResNetL(self.args.net_index, self.dataset.n_class, pruned_channel_list)
    learner = FullLearner(self.dataset, pruned_net, device=self.device, args=self.args,
                          teacher=self.teacher)
    print('FLOPs:', learner.cnt_flops())
    learner.train(n_epoch=n_epoch, save_path=save_path)
def forward(self, x, tau=1, noise=False):
    """Run the gated ResNet.

    Returns:
        ``(logits, prob_list, total_flops, flops_list)`` where ``prob_list`` and
        ``flops_list`` hold the stem's entry, then every block's entries in
        order, and ``flops_list`` ends with the fc layer's FLOPs;
        ``total_flops`` is their stacked sum.
    """
    x, rmask, prob, stem_flops = self.conv0(x, tau, noise)
    prob_list = [prob]
    flops_list = [stem_flops]
    if self.imagenet:
        x = self.maxpool(x)

    for stage in self.block_list:
        for block in stage:
            x, rmask, _, block_probs, block_flops = block(
                x, tau, noise, reuse_prob=prob, rmask=rmask)
            prob_list.extend(block_probs)
            flops_list.extend(block_flops)
            # The next block's reuse_prob is this block's last prob entry
            # (the shortcut's, when the block has a projection).
            prob = block_probs[-1]

    x = F.relu(DNAS.weighted_feature(self.bn(x), rmask))
    x = torch.flatten(self.avgpool(x), 1)
    x, fc_flops = self.fc(x, p_in=prob)
    flops_list.append(fc_flops)
    return x, prob_list, torch.sum(torch.stack(flops_list)), flops_list
def ResNetGated(n_layer, n_class):
    """Build a searchable (gated) ResNet of depth ``n_layer``.

    Uses a fresh ``DcpConfig`` (n_param=8, split type TYPE_A, no reuse gate
    yet) together with the full-width channel list for ``n_layer``.
    """
    config = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)
    full_channels = ResNetChannelList(n_layer)
    return ResNet(n_layer, n_class, channel_lists=full_channels, dcfg=config)