Example #1
def dynamic_prune_inactive_neural_with_feature_extractor(net,
                                                         net_name,
                                                         exp_name,
                                                         target_accuracy,
                                                         initial_prune_rate,
                                                         delta_prune_rate=0.02,
                                                         round_for_train=2,
                                                         tar_acc_gradual_decent=False,
                                                         flop_expected=None,
                                                         dataset_name='imagenet',
                                                         batch_size=conf.batch_size,
                                                         num_workers=conf.num_workers,
                                                         learning_rate=0.01,
                                                         evaluate_step=1200,
                                                         num_epoch=450,
                                                         num_epoch_after_finetune=20,
                                                         filter_preserve_ratio=0.3,
                                                         # max_filters_pruned_for_one_time=0.5,
                                                         optimizer=optim.Adam,
                                                         learning_rate_decay=False,
                                                         learning_rate_decay_factor=conf.learning_rate_decay_factor,
                                                         weight_decay=conf.weight_decay,
                                                         learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
                                                         max_training_round=9999,
                                                         round=1,
                                                         top_acc=1,
                                                         max_data_to_test=10000,
                                                         **kwargs
                                                         ):
    '''

       :param net:
       :param net_name:
       :param exp_name:
       :param target_accuracy:
       :param initial_prune_rate:
       :param delta_prune_rate: increase/decrease of prune_rate after each round
       :param round_for_train:
       :param tar_acc_gradual_decent:
       :param flop_expected:
       :param dataset_name:
       :param batch_size:
       :param num_workers:
       :param optimizer:
       :param learning_rate:
       :param evaluate_step:
       :param num_epoch:
       :param num_epoch_after_finetune: epochs to train after fine-tuning the pruned net
       :param filter_preserve_ratio:
       :param max_filters_pruned_for_one_time:
       :param learning_rate_decay:
       :param learning_rate_decay_factor:
       :param weight_decay:
       :param learning_rate_decay_epoch:
       :param max_training_round: if the net can't reach the target accuracy within max_training_round rounds of training, the program stops.
       :param top_acc:
       :param kwargs:
       :return:
       '''

    # save the output to log
    print('save log in:' + os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'))
    if not os.path.exists(os.path.join(conf.root_path, 'model_saved', exp_name)):
        os.makedirs(os.path.join(conf.root_path, 'model_saved', exp_name), exist_ok=True)
    sys.stdout = logger.Logger(os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'), sys.stdout)
    sys.stderr = logger.Logger(os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'),
                               sys.stderr)  # redirect std err, if necessary

    print('net:', net)
    print('net_name:', net_name)
    print('exp_name:', exp_name)
    print('target_accuracy:', target_accuracy)
    print('initial_prune_rate:', initial_prune_rate)
    print('delta_prune_rate:',delta_prune_rate)
    print('round_for_train:', round_for_train)
    print('tar_acc_gradual_decent:', tar_acc_gradual_decent)
    print('flop_expected:', flop_expected)
    print('dataset_name:', dataset_name)
    print('batch_size:', batch_size)
    print('num_workers:', num_workers)
    print('optimizer:', optimizer)
    print('learning_rate:', learning_rate)
    print('evaluate_step:', evaluate_step)
    print('num_epoch:', num_epoch)
    print('num_epoch_after_finetune:',num_epoch_after_finetune)
    print('filter_preserve_ratio:', filter_preserve_ratio)
    # print('max_filters_pruned_for_one_time:', max_filters_pruned_for_one_time)
    print('learning_rate_decay:', learning_rate_decay)
    print('learning_rate_decay_factor:', learning_rate_decay_factor)
    print('weight_decay:', weight_decay)
    print('learning_rate_decay_epoch:', learning_rate_decay_epoch)
    print('max_training_round:', max_training_round)
    print('top_acc:', top_acc)
    print('round:', round)
    print(kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(),' * ',end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    checkpoint_path=os.path.join(conf.root_path, 'model_saved', exp_name)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    validation_loader = data_loader.create_validation_loader(
        batch_size=batch_size,
        num_workers=num_workers,
        dataset_name=dataset_name,
    )

    if isinstance(net,nn.DataParallel):
        net_entity=net.module
    else:
        net_entity=net

    flop_original_net = flop_before_prune = measure_model(net_entity.prune(), dataset_name)

    original_accuracy = evaluate.evaluate_net(net=net,
                                              data_loader=validation_loader,
                                              save_net=False,
                                              dataset_name=dataset_name,
                                              top_acc=top_acc
                                              )
    if tar_acc_gradual_decent is True:
        if flop_expected<=1:
            flop_expected=int(flop_original_net*flop_expected)
        flop_drop_expected = flop_original_net - flop_expected
        acc_drop_tolerance = original_accuracy - target_accuracy

    conv_list, filter_num, filter_num_lower_bound=get_information_for_pruned_conv(net_entity,net_name,filter_preserve_ratio)
    num_filters_to_prune_at_most=[]
    for i in range(len(filter_num)):
        num_filters_to_prune_at_most+=[filter_num[i]-filter_num_lower_bound[i]]

    prune_rate=initial_prune_rate
    while True:
        print('{} start round {} of filter pruning.'.format(datetime.now(), round))
        print('{} current prune_rate:{}'.format(datetime.now(),prune_rate))
        # todo: to be reconsidered
        net_entity.reset_mask()
        if round <= round_for_train:
            dead_filter_index, module_list, neural_list, FIRE_tmp = evaluate.find_useless_filters_data_version(
                net=net_entity,
                batch_size=16,  # this function needs to run on a single GPU
                percent_of_inactive_filter=prune_rate,
                dead_or_inactive='inactive',
                dataset_name=dataset_name,
                max_data_to_test=max_data_to_test,
                num_filters_to_prune_at_most=num_filters_to_prune_at_most,
            )
            num_test_images = math.ceil(max_data_to_test / 16) * 16  # warning: this is wrong if the dataset has fewer than max_data_to_test images
            if not os.path.exists(os.path.join(checkpoint_path, 'dead_neural')):
                os.makedirs(os.path.join(checkpoint_path, 'dead_neural'), exist_ok=True)

            checkpoint = {'prune_rate': prune_rate, 'module_list': module_list,
                          'neural_list': neural_list, 'state_dict': net_entity.state_dict(),
                          'num_test_images':num_test_images}
            checkpoint.update(storage.get_net_information(net_entity, dataset_name, net_name))
            torch.save(checkpoint,
                       os.path.join(checkpoint_path, 'dead_neural/round %d.tar' % round)
                       )
        else:
            print('round {} > round_for_train: the filter feature extractor branch is not implemented, stopping.'.format(round))
            return

        # todo: for round > round_for_train, use the filter feature extractor

        '''filter pruning'''
        for i in conv_list:
            # todo: to be reconsidered
            # # ensure the number of filters pruned will not be too large for one time
            # if type(max_filters_pruned_for_one_time) is list:
            #     num_filters_to_prune_max = filter_num[i] * max_filters_pruned_for_one_time[i]
            # else:
            #     num_filters_to_prune_max = filter_num[i] * max_filters_pruned_for_one_time
            # if num_filters_to_prune_max < len(dead_filter_index[i]):
            #     dead_filter_index[i] = dead_filter_index[i][:int(num_filters_to_prune_max)]
            # ensure the lower bound of filter number
            if filter_num[i] - len(dead_filter_index[i]) < filter_num_lower_bound[i]:
                # this should not happen, since num_filters_to_prune_at_most already bounds the selection
                raise Exception('layer {}: pruning would drop below the filter lower bound'.format(i))
                # unreachable fallback: dead_filter_index[i] = dead_filter_index[i][:filter_num[i] - filter_num_lower_bound[i]]

            print('layer {}: has {} filters, prunes {} filters, remains {} filters.'.
                  format(i, filter_num[i], len(dead_filter_index[i]),filter_num[i]-len(dead_filter_index[i])))
            net_entity.mask_filters(i,dead_filter_index[i])

        flop_after_prune = measure_model(net_entity.prune(), dataset_name)
        net_compressed = (flop_after_prune != flop_before_prune)
        if not net_compressed:
            round -= 1
            print('{} round {} did not prune any filters. Restart.'.format(datetime.now(), round + 1))
            prune_rate += 0.02
            continue


        if tar_acc_gradual_decent is True:  # decrease the target_accuracy as FLOPs are reduced
            flop_reduced = flop_original_net - flop_after_prune
            target_accuracy = original_accuracy - acc_drop_tolerance * (flop_reduced / flop_drop_expected)
            print('{} current target accuracy:{}'.format(datetime.now(), target_accuracy))

        success = False
        training_round=0
        while not success:
            old_net = deepcopy(net)
            training_round+=1
            success = train.train(net=net,
                                  net_name=net_name,
                                  exp_name=exp_name,
                                  num_epochs=num_epoch,
                                  target_accuracy=target_accuracy,
                                  learning_rate=learning_rate,
                                  load_net=False,
                                  evaluate_step=evaluate_step,
                                  dataset_name=dataset_name,
                                  optimizer=optimizer,
                                  batch_size=batch_size,
                                  learning_rate_decay=learning_rate_decay,
                                  learning_rate_decay_factor=learning_rate_decay_factor,
                                  weight_decay=weight_decay,
                                  learning_rate_decay_epoch=learning_rate_decay_epoch,
                                  test_net=True,
                                  top_acc=top_acc,
                                  no_grad=['mask']
                                  )
            if success:
                prune_rate+= delta_prune_rate
                round += 1
                flop_before_prune=flop_after_prune
                net_entity.reset_mask()
                if num_epoch_after_finetune ==0:
                    break
                train.train(net=net,
                            net_name=net_name,
                            exp_name=exp_name,
                            num_epochs=num_epoch_after_finetune,
                            learning_rate=learning_rate,
                            load_net=False,
                            evaluate_step=evaluate_step,
                            dataset_name=dataset_name,
                            optimizer=optimizer,
                            batch_size=batch_size,
                            test_net=True,
                            top_acc=top_acc,
                            no_grad=['mask']
                            )
            else:
                net = old_net
                if isinstance(net, nn.DataParallel):
                    net_entity = net.module
                else:
                    net_entity = net
                net_entity.reset_mask()
                if max_training_round == training_round:
                    prune_rate-=delta_prune_rate/2
                    if prune_rate<=0:
                        print('{} failed to prune the net, pruning stop.'.format(datetime.now()))
                        return
                    break
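
# --- hedged usage sketch (not part of the original file) ---
# A minimal example of how the routine above might be invoked. It assumes that
# NetWithMask (the mask-based wrapper shown in Example #2) is importable here;
# the exp_name and the accuracy/rate values are illustrative assumptions only.
if __name__ == "__main__":
    net = NetWithMask(dataset_name='cifar10', net_name='vgg16_bn')
    dynamic_prune_inactive_neural_with_feature_extractor(
        net=net,
        net_name='vgg16_bn',
        exp_name='vgg16bn_cifar10_dynamic_prune_demo',  # hypothetical experiment name
        target_accuracy=0.93,                           # assumed target
        initial_prune_rate=0.1,
        delta_prune_rate=0.02,
        dataset_name='cifar10',
        batch_size=512,
        num_epoch=20)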
Example #2
                                         groups=conv.groups,
                                         bias=(conv.bias is not None))
        self.weight = conv.weight
        if self.bias is not None:
            self.bias = conv.bias
        self.mask = mask
        self.front_mask = front_mask

    def forward(self, input):
        #mask the pruned filter and channel
        masked_weight = self.weight * self.mask.detach().reshape(
            (-1, 1, 1, 1)) * self.front_mask.detach().reshape((1, -1, 1, 1))
        masked_bias = self.bias * self.mask.detach()
        out = F.conv2d(input, masked_weight, masked_bias, self.stride,
                       self.padding, self.dilation, self.groups)

        return out


if __name__ == "__main__":
    net = NetWithMask(dataset_name='imagenet', net_name='vgg16_bn')
    #
    # tmp=net.prune()
    # net.mask_filters(layer_index=0,filter_index=[0,2,5])

    from framework import evaluate, data_loader
    dl = data_loader.create_validation_loader(batch_size=512,
                                              num_workers=2,
                                              dataset_name='cifar10')
    evaluate.evaluate_net(net, dl, save_net=False)
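
    # --- hedged sketch (not in the original file): what the mask does ---
    # The masked forward above multiplies each filter's weights by `mask` and each
    # input channel by `front_mask`, so a zeroed mask entry silences that output
    # channel. A tiny self-contained demonstration using only torch:
    import torch
    import torch.nn.functional as F
    weight = torch.randn(4, 3, 3, 3)            # 4 filters over 3 input channels
    bias = torch.randn(4)
    mask = torch.tensor([1., 0., 1., 1.])       # "prune" filter 1
    front_mask = torch.ones(3)                  # keep every input channel
    x = torch.randn(1, 3, 8, 8)
    masked_weight = weight * mask.reshape((-1, 1, 1, 1)) * front_mask.reshape((1, -1, 1, 1))
    masked_bias = bias * mask
    out = F.conv2d(x, masked_weight, masked_bias, stride=1, padding=1)
    print(out[:, 1].abs().max())                # 0: the masked filter contributes nothing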
Example #3
def speed_up_pruned_net():
    fontsize = 15

    # vgg16_bn baseline and its pruned counterpart (uncomment to measure the vgg speed-up instead)
    # checkpoint = torch.load(
    #     '../data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    # net_original = storage.restore_net(checkpoint)
    # checkpoint = torch.load(
    #     '../data/model_saved/vgg16bn_cifar10_realdata_regressor6_大幅度/checkpoint/flop=39915982,accuracy=0.93200.tar'
    # )

    checkpoint = torch.load(
        '../data/baseline/resnet56_cifar10,accuracy=0.93280.tar')
    net_original = resnet_cifar.resnet56()
    net_original.load_state_dict(checkpoint['state_dict'])

    checkpoint = torch.load(
        '../data/model_saved/resnet56_cifar10_regressor_prunedBaseline2/checkpoint/flop=36145802,accuracy=0.92110.tar'
    )
    net_pruned = storage.restore_net(checkpoint)
    net_pruned.load_state_dict(checkpoint['state_dict'])

    # batch_size=[256,512,1024]
    num_workers = [i for i in range(4, 5)]
    batch_size = [300, 600, 1000, 1600]

    device_list = [torch.device('cuda')]  #
    # device_list=[torch.device('cpu')]
    for num_worker in num_workers:
        time_original = []
        time_pruned = []
        for d in device_list:
            for bs in batch_size:
                net_original.to(d)
                net_pruned.to(d)

                dl = data_loader.create_validation_loader(
                    batch_size=bs,
                    num_workers=num_worker,
                    dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_original,
                                      data_loader=dl,
                                      save_net=False,
                                      device=d)
                end_time = time.time()
                time_original.append(end_time - start_time)
                del dl

                dl = data_loader.create_validation_loader(
                    batch_size=bs,
                    num_workers=num_worker,
                    dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_pruned,
                                      data_loader=dl,
                                      save_net=False,
                                      device=d)
                end_time = time.time()
                time_pruned.append(end_time - start_time)
                del dl

        print('time before pruned:', time_original)
        print('time after pruned:', time_pruned)
        acceleration = np.array(time_original) / np.array(time_pruned)
        baseline = np.ones(shape=len(batch_size))
        x_tick = range(len(baseline))

        plt.figure()
        plt.bar(x_tick, acceleration, color='blue', hatch='//')  #,label='GPU')
        # plt.bar(x_tick[len(batch_size):], acceleration[len(batch_size):], color='grey', hatch='\\', label='CPU')
        # plt.bar(x_tick,baseline,color='red',hatch='*',label='Baseline')
        plt.xticks(x_tick, batch_size, fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        for x, y in enumerate(list(acceleration)):
            plt.text(x, y + 0.1, '%.2f x' % y, ha='center', fontsize=fontsize)
        plt.ylim([0, np.max(acceleration) + 0.5])
        plt.xlabel('Batch-Size', fontsize=fontsize)
        plt.ylabel('Speed-Up', fontsize=fontsize)
        # plt.legend(loc='upper left')
        plt.savefig('resnet_gpu_speed_up.eps', format='eps')
        # plt.savefig(str(num_worker)+'speed_up.jpg')
        plt.show()
        print()
Example #4
# def conversion(dataset_name,net_name,checkpoint_path='',checkpoint=None):

# convert a resnet trained on cifar from the old checkpoint format
#     checkpoint=torch.load(checkpoint_path)
#     # net=checkpoint.pop('net')
#     net=resnet_cifar.resnet32()
#     checkpoint['state_dict']['fc.weight']=checkpoint['state_dict'].pop('linear.weight')
#     checkpoint['state_dict']['fc.bias']=checkpoint['state_dict'].pop('linear.bias')
#
#     net.load_state_dict(checkpoint['state_dict'])
#     checkpoint.update(get_net_information(net,dataset_name,net_name))
#     torch.save(checkpoint,checkpoint_path)

if __name__ == "__main__":
    # conversion(checkpoint_path='../data/baseline/resnet32_cifar10,accuracy=0.92380.tar',net_name='resnet56',dataset_name='cifar10')

    checkpoint = torch.load(
        '/home/victorfang/model_pytorch/data/baseline/resnet56_cifar10,accuracy=0.94230.tar'
    )

    net = resnet_cifar.resnet56()
    # net.load_state_dict(checkpoint['state_dict'])
    c = get_net_information(net=net,
                            dataset_name='cifar10',
                            net_name='resnet56')
    net = restore_net(checkpoint, True)
    from framework import evaluate, data_loader
    evaluate.evaluate_net(
        net, data_loader.create_validation_loader(512, 4, 'cifar10'), False)
    # c=get_net_information(net=net,dataset_name=dataset_name,net_name='resnet50')
Example #5
#     c_new = {'highest_accuracy': c_original['highest_accuracy'],
#              'state_dict': new_state_dict}
#
#     torch.save(c_new, path)

if __name__ == "__main__":
    import torch
    from network import storage
    from framework import evaluate, data_loader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(
        '/home/victorfang/model_pytorch/data/baseline/resnet56_cifar100_0.71580.tar'
    )
    # c_sample=torch.load('/home/victorfang/model_pytorch/data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    net = resnet56(num_classes=100)  # cifar100 has 100 classes
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)

    # checkpoint.update(storage.get_net_information(net=net,dataset_name='tiny_imagenet',net_name='resnet18'))
    # checkpoint.pop('net')
    # torch.save(checkpoint,'/home/victorfang/model_pytorch/data/baseline/resnet18_tinyimagenet_v2_0.72990.tar')
    # net=storage.restore_net(checkpoint=checkpoint,pretrained=True)
    # net=nn.DataParallel(net)
    evaluate.evaluate_net(net=net,
                          data_loader=data_loader.create_validation_loader(
                              512, 8, 'cifar100'),
                          save_net=False,
                          dataset_name='cifar100')

    print()
Example #6
def train(
        net,
        net_name,
        exp_name='',
        dataset_name='imagenet',
        train_loader=None,
        validation_loader=None,
        learning_rate=conf.learning_rate,
        num_epochs=conf.num_epochs,
        batch_size=conf.batch_size,
        evaluate_step=conf.evaluate_step,
        load_net=True,
        test_net=False,
        root_path=conf.root_path,
        checkpoint_path=None,
        momentum=conf.momentum,
        num_workers=conf.num_workers,
        learning_rate_decay=False,
        learning_rate_decay_factor=conf.learning_rate_decay_factor,
        learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
        weight_decay=conf.weight_decay,
        target_accuracy=1.0,
        optimizer=optim.SGD,
        top_acc=1,
        criterion=nn.CrossEntropyLoss(),  # default loss is cross-entropy, commonly used for multi-class classification
        no_grad=[],
        scheduler_name='MultiStepLR',
        eta_min=0,
        #todo:tmp!!!
        data_parallel=False):
    '''

    :param net: net to be trained
    :param net_name: name of the net
    :param exp_name: name of the experiment
    :param dataset_name: name of the dataset
    :param train_loader: data_loader for training. If not provided, a data_loader will be created based on dataset_name
    :param validation_loader: data_loader for validation. If not provided, a data_loader will be created based on dataset_name
    :param learning_rate: initial learning rate
    :param learning_rate_decay: boolean, if true, the learning rate will decay based on the params provided.
    :param learning_rate_decay_factor: float. learning_rate *= learning_rate_decay_factor each time the learning rate decays.
    :param learning_rate_decay_epoch: list[int], the specific epoch that the learning rate will decay.
    :param num_epochs: max number of epochs for training
    :param batch_size:
    :param evaluate_step: how often the net will be tested on the validation set. At least one test per epoch is guaranteed
    :param load_net: boolean, whether loading net from previous checkpoint. The newest checkpoint will be selected.
    :param test_net: boolean, if true, the net will be tested before training.
    :param root_path:
    :param checkpoint_path:
    :param momentum:
    :param num_workers:
    :param weight_decay:
    :param target_accuracy: float, the training will stop once the net reaches the target accuracy
    :param optimizer:
    :param top_acc: can be 1 or 5
    :param criterion: loss function
    :param no_grad: list containing names of the modules that do not need to be trained
    :param scheduler_name: learning rate scheduler, either 'MultiStepLR' or 'CosineAnnealingLR'
    :param eta_min: for CosineAnnealingLR
    :return: True if the net reaches target_accuracy during training, otherwise False. A usage sketch follows this function.
    '''
    success = True  # flag returned when the trained net reaches the target accuracy
    # gpu or not
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(), ' * ', end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    #prepare the data
    if dataset_name == 'imagenet':
        mean = conf.imagenet['mean']
        std = conf.imagenet['std']
        train_set_path = conf.imagenet['train_set_path']
        train_set_size = conf.imagenet['train_set_size']
        validation_set_path = conf.imagenet['validation_set_path']
        default_image_size = conf.imagenet['default_image_size']
    elif dataset_name == 'cifar10':
        train_set_size = conf.cifar10['train_set_size']
        mean = conf.cifar10['mean']
        std = conf.cifar10['std']
        train_set_path = conf.cifar10['dataset_path']
        validation_set_path = conf.cifar10['dataset_path']
        default_image_size = conf.cifar10['default_image_size']
    elif dataset_name == 'tiny_imagenet':
        train_set_size = conf.tiny_imagenet['train_set_size']
        mean = conf.tiny_imagenet['mean']
        std = conf.tiny_imagenet['std']
        train_set_path = conf.tiny_imagenet['train_set_path']
        validation_set_path = conf.tiny_imagenet['validation_set_path']
        default_image_size = conf.tiny_imagenet['default_image_size']
    elif dataset_name == 'cifar100':
        train_set_size = conf.cifar100['train_set_size']
        mean = conf.cifar100['mean']
        std = conf.cifar100['std']
        train_set_path = conf.cifar100['dataset_path']
        validation_set_path = conf.cifar100['dataset_path']
        default_image_size = conf.cifar100['default_image_size']
    if train_loader is None:
        train_loader = data_loader.create_train_loader(
            dataset_path=train_set_path,
            default_image_size=default_image_size,
            mean=mean,
            std=std,
            batch_size=batch_size,
            num_workers=num_workers,
            dataset_name=dataset_name)
    if validation_loader is None:
        validation_loader = data_loader.create_validation_loader(
            dataset_path=validation_set_path,
            default_image_size=default_image_size,
            mean=mean,
            std=std,
            batch_size=batch_size,
            num_workers=num_workers,
            dataset_name=dataset_name)

    if checkpoint_path is None:
        checkpoint_path = os.path.join(root_path, 'model_saved', exp_name,
                                       'checkpoint')
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    #get the latest checkpoint
    lists = os.listdir(checkpoint_path)
    file_new = checkpoint_path
    if len(lists) > 0:
        lists.sort(key=lambda fn: os.path.getmtime(checkpoint_path + "/" + fn)
                   )  # sort by modification time
        file_new = os.path.join(checkpoint_path,
                                lists[-1])  # the newest checkpoint file

    sample_num = 0
    if os.path.isfile(file_new):
        if load_net:
            checkpoint = torch.load(file_new)
            print('{} load net from previous checkpoint:{}'.format(
                datetime.now(), file_new))
            # net.load_state_dict(checkpoint['state_dict'])
            net = storage.restore_net(checkpoint,
                                      pretrained=True,
                                      data_parallel=data_parallel)
            sample_num = checkpoint['sample_num']

    if test_net:
        print('{} test the net'.format(datetime.now()))
        net_test = copy.deepcopy(net)
        accuracy = evaluate.evaluate_net(net_test,
                                         validation_loader,
                                         save_net=True,
                                         checkpoint_path=checkpoint_path,
                                         sample_num=sample_num,
                                         target_accuracy=target_accuracy,
                                         dataset_name=dataset_name,
                                         top_acc=top_acc,
                                         net_name=net_name,
                                         exp_name=exp_name)
        del net_test

        if accuracy >= target_accuracy:
            print('{} net reached target accuracy.'.format(datetime.now()))
            return success

    #ensure the net will be evaluated despite the inappropriate evaluate_step
    if evaluate_step > math.ceil(train_set_size / batch_size) - 1:
        evaluate_step = math.ceil(train_set_size / batch_size) - 1

    optimizer = prepare_optimizer(net, optimizer, no_grad, momentum,
                                  learning_rate, weight_decay)
    if learning_rate_decay:
        if scheduler_name == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=learning_rate_decay_epoch,
                gamma=learning_rate_decay_factor,
                last_epoch=ceil(sample_num / train_set_size))
        elif scheduler_name == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer,
                num_epochs,
                eta_min=eta_min,
                last_epoch=ceil(sample_num / train_set_size))
    print("{} Start training ".format(datetime.now()) + net_name + "...")
    for epoch in range(math.floor(sample_num / train_set_size), num_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        net.train()
        # one epoch for one loop
        for step, data in enumerate(train_loader, 0):
            if sample_num / train_set_size == epoch + 1:  #one epoch of training finished
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(
                    net_test,
                    validation_loader,
                    save_net=True,
                    checkpoint_path=checkpoint_path,
                    sample_num=sample_num,
                    target_accuracy=target_accuracy,
                    dataset_name=dataset_name,
                    top_acc=top_acc,
                    net_name=net_name,
                    exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(
                        datetime.now()))
                    return success
                break

            # prepare the data
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            sample_num += int(images.shape[0])

            # if learning_rate_decay:
            #     exponential_decay_learning_rate(optimizer=optimizer,
            #                                     sample_num=sample_num,
            #                                     learning_rate_decay_factor=learning_rate_decay_factor,
            #                                     train_set_size=train_set_size,
            #                                     learning_rate_decay_epoch=learning_rate_decay_epoch,
            #                                     batch_size=batch_size)

            optimizer.zero_grad()
            # forward + backward
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 60 == 0:
                print('{} loss is {}'.format(datetime.now(), float(loss.data)))

            if step % evaluate_step == 0 and step != 0:
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(
                    net_test,
                    validation_loader,
                    save_net=True,
                    checkpoint_path=checkpoint_path,
                    sample_num=sample_num,
                    target_accuracy=target_accuracy,
                    dataset_name=dataset_name,
                    top_acc=top_acc,
                    net_name=net_name,
                    exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(
                        datetime.now()))
                    return success
                accuracy = float(accuracy)
                print('{} continue training'.format(datetime.now()))
        if learning_rate_decay:
            scheduler.step()
            print(optimizer.state_dict()['param_groups'][0]['lr'])

    print("{} Training finished. Saving net...".format(datetime.now()))
    net_test = copy.deepcopy(net)
    flop_num = measure_flops.measure_model(net=net_test,
                                           dataset_name=dataset_name,
                                           print_flop=False)
    accuracy = evaluate.evaluate_net(net_test,
                                     validation_loader,
                                     save_net=True,
                                     checkpoint_path=checkpoint_path,
                                     sample_num=sample_num,
                                     target_accuracy=target_accuracy,
                                     dataset_name=dataset_name,
                                     top_acc=top_acc,
                                     net_name=net_name,
                                     exp_name=exp_name)
    accuracy = float(accuracy)
    checkpoint = {
        'highest_accuracy': accuracy,
        'state_dict': net.state_dict(),
        'sample_num': sample_num,
        'flop_num': flop_num
    }
    checkpoint.update(storage.get_net_information(net, dataset_name, net_name))
    torch.save(
        checkpoint,
        '%s/flop=%d,accuracy=%.5f.tar' % (checkpoint_path, flop_num, accuracy))
    print("{} net saved at sample num = {}".format(datetime.now(), sample_num))
    return not success
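
# --- hedged usage sketch (not in the original file) ---
# A minimal example of how train() might be called for a cifar10 baseline.
# The import path for resnet_cifar is assumed (mirroring `from network import
# storage` used elsewhere); exp_name and target_accuracy are illustrative.
if __name__ == "__main__":
    from network import resnet_cifar  # assumed location of the resnet builders
    net = resnet_cifar.resnet56()
    train(net=net,
          net_name='resnet56',
          exp_name='resnet56_cifar10_baseline_demo',  # hypothetical experiment name
          dataset_name='cifar10',
          num_epochs=160,
          batch_size=128,
          learning_rate=0.1,
          learning_rate_decay=True,
          learning_rate_decay_epoch=[80, 120],
          learning_rate_decay_factor=0.1,
          target_accuracy=0.93,                       # assumed target
          load_net=False,
          test_net=False)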