Esempio n. 1
0
 def train_prune(self, tau, n_epoch=250,
                 load_path='./models/search/model.pth',
                 save_path='./models/prune/model.pth'):
     """Load a saved model and run one forward pass on a single training batch.

     The original signature placed ``load_path`` and ``save_path`` (without
     defaults) after ``n_epoch=250``, which Python rejects as a syntax error.
     Both now carry defaults, so the parameter order is unchanged and
     positional callers keep working.

     Args:
         tau: Forwarded unchanged to ``self.forward`` as ``tau``.
         n_epoch: Accepted for interface compatibility; not read in the
             visible part of this method.
         load_path: Path handed to ``self.load_model``.
         save_path: Accepted for interface compatibility; not read in the
             visible part of this method.
     """
     self.load_model(load_path)
     # The namespace is spelled ``DNAS`` everywhere else in this file; the
     # lowercase ``dnas.TYPE_A`` used here previously did not match it.
     # NOTE(review): confirm no separate lowercase ``dnas`` module is intended.
     dcfg = DNAS.DcpConfig(n_param=8,
                           split_type=DNAS.TYPE_A,
                           reuse_gate=None)
     # Only the first batch produced by the training loader is used.
     data = next(iter(self.train_loader))
     # Assumes each batch is a two-element sequence whose items support
     # ``.to(device)`` -- TODO confirm against the loader.
     in_data, im_gt = [x.to(self.device) for x in data]
     self.net.eval()
     outputs, prob_list, flops, flops_list = self.forward(in_data,
                                                          tau=tau,
                                                          noise=False)
Esempio n. 2
0
    def train_prune(self, tau, n_epoch=250,
                    load_path='./models/search/model.pth',
                    save_path='./models/prune/model.pth'):
        """Derive a pruned channel layout from the searched model and train it.

        Steps, in order:
          0. Load the searched model from ``load_path``, call
             ``self.net.eval()`` and run ``self.forward`` on the first
             training batch to obtain ``prob_list``.
          1. Turn ``prob_list`` into a pruned channel list with
             ``get_prune_list`` and build a ``ResNetL`` from it.
          2. Wrap that network in a ``FullLearner`` and train it for
             ``n_epoch`` epochs, passing ``save_path`` through.

        Diagnostics are printed along the way: ``tau``, every entry of
        ``prob_list`` with two decimals, two slices of each parameter whose
        name contains ``'conv0.conv.weight'``, the pruned channel list and
        the value returned by ``cnt_flops()``.

        Args:
            tau: Forwarded to ``self.forward`` as ``tau`` and printed.
            n_epoch: Forwarded to ``FullLearner.train``.
            load_path: Path handed to ``self.load_model``.
            save_path: Forwarded to ``FullLearner.train``.
        """
        self.load_model(load_path)
        gate_cfg = DNAS.DcpConfig(
            n_param=8,
            split_type=DNAS.TYPE_A,
            reuse_gate=None,
        )
        base_channels = ResNetChannelList(self.args.teacher_net_index)

        self.net.eval()
        batch = next(iter(self.train_loader))
        inputs = batch[0].to(self.device)
        # The second batch element is moved to the device as before, although
        # nothing below reads it.
        _labels = batch[1].to(self.device)
        _outputs, prob_list, _flops, _flops_list = self.forward(
            inputs, tau=tau, noise=False)

        # Dump tau and one line of probabilities per entry of prob_list.
        print('=================')
        print(tau)
        for gate_prob in prob_list:
            for value in gate_prob.tolist():
                print('{0:.2f}'.format(value), end=' ')
            print()
        print('------------------')

        # Show two slices of every first-conv weight tensor for inspection.
        for param_name, param in self.forward.named_parameters():
            if 'conv0.conv.weight' not in param_name:
                continue
            print(param[:, 0, 0, 0])
            print(param[:, 1, 2, 1])

        # A value recorded in an earlier comment here, shown only to
        # illustrate the nesting of the result:
        #   [16, [[10, 12, 12], [7, 12], [11, 12]],
        #        [[31, 25, 25], [19, 25], [32, 25]],
        #        [[39, 43, 43], [61, 43], [41, 43]]]
        channel_list_prune = get_prune_list(base_channels, prob_list, dcfg=gate_cfg)
        print(channel_list_prune)

        slim_net = ResNetL(self.args.net_index, self.dataset.n_class, channel_list_prune)
        learner = FullLearner(
            self.dataset,
            slim_net,
            device=self.device,
            args=self.args,
            teacher=self.teacher,
        )
        flop_count = learner.cnt_flops()
        print('FLOPs:', flop_count)
        learner.train(n_epoch=n_epoch, save_path=save_path)
Esempio n. 3
0
def ResNetGated(n_layer, n_class):
    """Build a ``ResNet`` from the stock channel list and a fixed DcpConfig.

    Args:
        n_layer: Passed to ``ResNetChannelList`` and to ``ResNet``.
        n_class: Passed to ``ResNet`` as its second argument.

    Returns:
        The object produced by ``ResNet(n_layer, n_class, channels, config)``.
    """
    gate_cfg = DNAS.DcpConfig(
        n_param=8,
        split_type=DNAS.TYPE_A,
        reuse_gate=None,
    )
    channels = ResNetChannelList(n_layer)
    model = ResNet(n_layer, n_class, channels, gate_cfg)
    return model