def train_prune(self, tau, n_epoch=250, load_path='./models/search/model.pth', save_path='./models/prune/model.pth'):
    """Load the searched model and run one gated forward pass on a training batch.

    NOTE(review): this body stops after the forward pass and returns nothing;
    it looks like a partial draft. A later ``train_prune`` definition in this
    file appears to supersede it -- confirm and consider removing this one.

    Args:
        tau: Value forwarded to ``self.forward`` as its ``tau`` keyword.
        n_epoch: Not used by this body; kept for signature compatibility.
        load_path: Path handed to ``self.load_model``.
        save_path: Not used by this body; kept for signature compatibility.
    """
    self.load_model(load_path)
    # Every other reference in this file spells the module ``DNAS``; the
    # lowercase ``dnas`` used previously is not a name this file uses elsewhere.
    dcfg = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)

    # Take a single batch and move both of its elements to the target device.
    data = next(iter(self.train_loader))
    in_data, im_gt = [x.to(self.device) for x in data]

    self.net.eval()
    outputs, prob_list, flops, flops_list = self.forward(in_data, tau=tau, noise=False)
def train_prune(self, tau, n_epoch=250, load_path='./models/search/model.pth', save_path='./models/prune/model.pth'):
    """Derive a pruned channel layout from the searched model and train it.

    Steps performed, in order:
      0. Load the searched model from ``load_path`` and run one forward pass
         on a single training batch to obtain the per-gate probabilities.
      1. Turn those probabilities into a pruned channel list and build a
         ``ResNetL`` network from it.
      2. Wrap that network in a ``FullLearner`` (with ``self.teacher``),
         print its FLOPs and train it.

    Args:
        tau: Value forwarded to ``self.forward`` as its ``tau`` keyword.
        n_epoch: Number of epochs passed to ``FullLearner.train``.
        load_path: Path handed to ``self.load_model``.
        save_path: Path passed to ``FullLearner.train`` as ``save_path``.
    """
    # --- Step 0: restore the searched model and collect gate probabilities.
    self.load_model(load_path)
    gate_cfg = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)
    base_channels = ResNetChannelList(self.args.teacher_net_index)

    self.net.eval()
    batch = next(iter(self.train_loader))
    images = batch[0].to(self.device)
    targets = batch[1].to(self.device)  # moved to device but not used below
    _outputs, gate_probs, _flops, _flops_list = self.forward(images, tau=tau, noise=False)

    # Debug dump: tau, then one line of two-decimal probabilities per gate.
    print('=================')
    print(tau)
    for gate_prob in gate_probs:
        for value in gate_prob.tolist():
            print('{0:.2f}'.format(value), end=' ')
        print()
    print('------------------')

    # Debug dump: two fixed slices of every parameter whose name contains
    # 'conv0.conv.weight'.
    for entry in self.forward.named_parameters():
        if 'conv0.conv.weight' not in entry[0]:
            continue
        weight = entry[1]
        print(weight[:, 0, 0, 0])
        print(weight[:, 1, 2, 1])

    # --- Step 1: build the slim network from the pruned channel layout.
    # Example layout recorded by the original author:
    #   [16,
    #    [[10, 12, 12], [7, 12], [11, 12]],
    #    [[31, 25, 25], [19, 25], [32, 25]],
    #    [[39, 43, 43], [61, 43], [41, 43]]]
    pruned_channels = get_prune_list(base_channels, gate_probs, dcfg=gate_cfg)
    print(pruned_channels)
    slim_net = ResNetL(self.args.net_index, self.dataset.n_class, pruned_channels)

    # --- Step 2: train and validate the slim network with the full learner.
    learner = FullLearner(
        self.dataset,
        slim_net,
        device=self.device,
        args=self.args,
        teacher=self.teacher,
    )
    print('FLOPs:', learner.cnt_flops())
    learner.train(n_epoch=n_epoch, save_path=save_path)
def ResNetGated(n_layer, n_class):
    """Construct a ``ResNet`` with a fixed ``DcpConfig`` and its default channel list.

    Args:
        n_layer: Passed to ``ResNetChannelList`` and to ``ResNet``.
        n_class: Passed to ``ResNet`` as its second argument.

    Returns:
        The object returned by ``ResNet(n_layer, n_class, channel_list, dcfg)``.
    """
    gate_cfg = DNAS.DcpConfig(
        n_param=8,
        split_type=DNAS.TYPE_A,
        reuse_gate=None,
    )
    channels = ResNetChannelList(n_layer)
    network = ResNet(n_layer, n_class, channels, gate_cfg)
    return network