Example #1
def plot_dead_filter_num_with_different_fdt():
    # print()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #
    # # net=create_net.vgg_cifar10()
    checkpoint = torch.load(
        '../data/baseline/resnet56_cifar10,accuracy=0.94230.tar')
    net = resnet_cifar.resnet56().to(device)
    net.load_state_dict(checkpoint['state_dict'])

    val_loader = data_loader.create_validation_loader(batch_size=500,
                                                      num_workers=6,
                                                      dataset_name='cifar10')
    relu_list, neural_list = evaluate.check_ReLU_alive(net=net,
                                                       neural_dead_times=8000,
                                                       data_loader=val_loader)

    # net=vgg.vgg16_bn(pretrained=False)
    # checkpoint=torch.load('/home/victorfang/Desktop/vgg16_bn_imagenet_deadReLU.tar')
    # net=resnet.resnet34(pretrained=True)
    # checkpoint=torch.load('/home/victorfang/Desktop/resnet34_imagenet_DeadReLU.tar')
    # neural_list=checkpoint['neural_list']
    # relu_list=checkpoint['relu_list']

    neural_dead_times = 8000
    # neural_dead_times=40000
    fdt_list = [0.001 * i for i in range(1, 1001)]
    dead_filter_num = []
    for fdt in fdt_list:
        dead_filter_num.append(
            dead_filter_statistics(net=net,
                                   neural_list=neural_list,
                                   neural_dead_times=neural_dead_times,
                                   filter_FIRE=fdt,
                                   relu_list=relu_list))
        if fdt == 0.8:
            print()
    plt.figure()
    plt.title('df')
    plt.plot(fdt_list, dead_filter_num)
    plt.xlabel('filter activation ratio')
    plt.ylabel('number of filters')
    plt.legend()
    plt.show()
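The helper dead_filter_statistics is not shown in this example. A minimal sketch of what it presumably computes, following the dead-filter rule used elsewhere in this collection (a neuron is dead if it was inactive at least neural_dead_times times; a filter is dead if at least a filter_FIRE fraction of its neurons are dead); the signature mirrors the call above and the counting rule is an assumption:

import numpy as np

def dead_filter_statistics_sketch(net, neural_list, neural_dead_times, filter_FIRE, relu_list):
    # count filters whose fraction of dead neurons reaches filter_FIRE;
    # neural_list maps each ReLU module to an array of shape (filters, H, W)
    # holding per-neuron dead counts, as in the other examples (net is unused here)
    dead_filter_num = 0
    for relu in relu_list:
        for key in neural_list:
            if relu is key:  # find the statistics recorded for this layer
                dead_times = np.array(neural_list[key])
                neural_num = dead_times.shape[1] * dead_times.shape[2]  # neurons per filter
                dead_neurons = np.sum(dead_times >= neural_dead_times, axis=(1, 2))
                dead_filter_num += int(np.sum(dead_neurons / neural_num >= filter_FIRE))
                break
    return dead_filter_num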
Example #2
def dead_neural_rate():
    # checkpoint=torch.load('/home/victorfang/Desktop/vgg16_bn_imagenet_deadReLU.tar')
    # checkpoint=torch.load('/home/victorfang/Desktop/resnet34_imagenet_DeadReLU.tar')
    # neural_list=checkpoint['neural_list']
    # relu_list=checkpoint['relu_list']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(
        '../data/baseline/resnet56_cifar10,accuracy=0.94230.tar')
    net = resnet_cifar.resnet56().to(device)
    net.load_state_dict(checkpoint['state_dict'])

    # net=create_net.vgg_cifar10()
    val_loader = data_loader.create_validation_loader(batch_size=1000,
                                                      num_workers=6,
                                                      dataset_name='cifar10')
    # train_loader=data_loader.create_train_loader(batch_size=1600,num_workers=6,dataset_name='cifar10')
    #
    relu_list, neural_list = evaluate.check_ReLU_alive(net=net,
                                                       neural_dead_times=10000,
                                                       data_loader=val_loader)
    # ndt_list=[i for i in range(35000,51000,1000)]
    ndt_list = [i for i in range(6000, 11000, 1000)]
    dead_rate = []
    for ndt in ndt_list:
        print(ndt)
        dead_rate.append(
            evaluate.cal_dead_neural_rate(neural_dead_times=ndt,
                                          neural_list_temp=neural_list))

    plt.figure()
    plt.title('df')
    plt.plot(ndt_list, dead_rate)
    plt.xlabel('neural dead times')
    plt.ylabel('neuron dead rate%')
    plt.legend()
    plt.show()
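evaluate.cal_dead_neural_rate is not shown here either; a plausible sketch, assuming it returns the percentage of neurons whose recorded dead count reaches neural_dead_times:

import numpy as np

def cal_dead_neural_rate_sketch(neural_dead_times, neural_list_temp):
    # fraction (in %) of neurons that were inactive at least neural_dead_times times
    dead, total = 0, 0
    for key in neural_list_temp:
        dead_times = np.array(neural_list_temp[key])
        dead += int(np.sum(dead_times >= neural_dead_times))
        total += dead_times.size
    return 100.0 * dead / total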
Example #3
def dynamic_prune_inactive_neural_with_feature_extractor(net,
                                                         net_name,
                                                         exp_name,
                                                         target_accuracy,
                                                         initial_prune_rate,
                                                         delta_prune_rate=0.02,
                                                         round_for_train=2,
                                                         tar_acc_gradual_decent=False,
                                                         flop_expected=None,
                                                         dataset_name='imagenet',
                                                         batch_size=conf.batch_size,
                                                         num_workers=conf.num_workers,
                                                         learning_rate=0.01,
                                                         evaluate_step=1200,
                                                         num_epoch=450,
                                                         num_epoch_after_finetune=20,
                                                         filter_preserve_ratio=0.3,
                                                         # max_filters_pruned_for_one_time=0.5,
                                                         optimizer=optim.Adam,
                                                         learning_rate_decay=False,
                                                         learning_rate_decay_factor=conf.learning_rate_decay_factor,
                                                         weight_decay=conf.weight_decay,
                                                         learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
                                                         max_training_round=9999,
                                                         round=1,
                                                         top_acc=1,
                                                         max_data_to_test=10000,
                                                         **kwargs
                                                         ):
    '''

       :param net:
       :param net_name:
       :param exp_name:
       :param target_accuracy:
       :param initial_prune_rate:
       :param delta_prune_rate: increase/decrease of prune_rate after each round
       :param round_for_train:
       :param tar_acc_gradual_decent:
       :param flop_expected:
       :param dataset_name:
       :param batch_size:
       :param num_workers:
       :param optimizer:
       :param learning_rate:
       :param evaluate_step:
       :param num_epoch:
       :param num_epoch_after_finetune: number of epochs to train after fine-tuning the pruned net
       :param filter_preserve_ratio:
       :param max_filters_pruned_for_one_time:
       :param learning_rate_decay:
       :param learning_rate_decay_factor:
       :param weight_decay:
       :param learning_rate_decay_epoch:
       :param max_training_round: if the net can't reach the target accuracy within max_training_round rounds, the program stops.
       :param top_acc:
       :param kwargs:
       :return:
       '''

    # save the output to log
    print('save log in:' + os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'))
    if not os.path.exists(os.path.join(conf.root_path, 'model_saved', exp_name)):
        os.makedirs(os.path.join(conf.root_path, 'model_saved', exp_name), exist_ok=True)
    sys.stdout = logger.Logger(os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'), sys.stdout)
    sys.stderr = logger.Logger(os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'),
                               sys.stderr)  # redirect std err, if necessary

    print('net:', net)
    print('net_name:', net_name)
    print('exp_name:', exp_name)
    print('target_accuracy:', target_accuracy)
    print('initial_prune_rate:', initial_prune_rate)
    print('delta_prune_rate:',delta_prune_rate)
    print('round_for_train:', round_for_train)
    print('tar_acc_gradual_decent:', tar_acc_gradual_decent)
    print('flop_expected:', flop_expected)
    print('dataset_name:', dataset_name)
    print('batch_size:', batch_size)
    print('num_workers:', num_workers)
    print('optimizer:', optimizer)
    print('learning_rate:', learning_rate)
    print('evaluate_step:', evaluate_step)
    print('num_epoch:', num_epoch)
    print('num_epoch_after_finetune:',num_epoch_after_finetune)
    print('filter_preserve_ratio:', filter_preserve_ratio)
    # print('max_filters_pruned_for_one_time:', max_filters_pruned_for_one_time)
    print('learning_rate_decay:', learning_rate_decay)
    print('learning_rate_decay_factor:', learning_rate_decay_factor)
    print('weight_decay:', weight_decay)
    print('learning_rate_decay_epoch:', learning_rate_decay_epoch)
    print('max_training_round:', max_training_round)
    print('top_acc:', top_acc)
    print('round:', round)
    print(kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(),' * ',end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    checkpoint_path=os.path.join(conf.root_path, 'model_saved', exp_name)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    validation_loader = data_loader.create_validation_loader(
        batch_size=batch_size,
        num_workers=num_workers,
        dataset_name=dataset_name,
    )

    if isinstance(net,nn.DataParallel):
        net_entity=net.module
    else:
        net_entity=net

    flop_original_net=flop_before_prune = measure_model(net_entity.prune(), dataset_name)

    original_accuracy = evaluate.evaluate_net(net=net,
                                              data_loader=validation_loader,
                                              save_net=False,
                                              dataset_name=dataset_name,
                                              top_acc=top_acc
                                              )
    if tar_acc_gradual_decent is True:
        if flop_expected<=1:
            flop_expected=int(flop_original_net*flop_expected)
        flop_drop_expected = flop_original_net - flop_expected
        acc_drop_tolerance = original_accuracy - target_accuracy

    conv_list, filter_num, filter_num_lower_bound=get_information_for_pruned_conv(net_entity,net_name,filter_preserve_ratio)
    num_filters_to_prune_at_most=[]
    for i in range(len(filter_num)):
        num_filters_to_prune_at_most+=[filter_num[i]-filter_num_lower_bound[i]]

    prune_rate=initial_prune_rate
    while True:
        print('{} start round {} of filter pruning.'.format(datetime.now(), round))
        print('{} current prune_rate:{}'.format(datetime.now(),prune_rate))
        # TODO: to be reconsidered
        net_entity.reset_mask()
        if round <= round_for_train:
            dead_filter_index, module_list, neural_list, FIRE_tmp = evaluate.find_useless_filters_data_version(
                net=net_entity,
                batch_size=16,  # this function needs to run on a single gpu
                percent_of_inactive_filter=prune_rate,
                dead_or_inactive='inactive',
                dataset_name=dataset_name,
                max_data_to_test=max_data_to_test,
                num_filters_to_prune_at_most=num_filters_to_prune_at_most,
            )
            num_test_images = math.ceil(max_data_to_test / 16) * 16  # warning: this is wrong if the dataset size is smaller than max_data_to_test
            if not os.path.exists(os.path.join(checkpoint_path, 'dead_neural')):
                os.makedirs(os.path.join(checkpoint_path, 'dead_neural'), exist_ok=True)

            checkpoint = {'prune_rate': prune_rate, 'module_list': module_list,
                          'neural_list': neural_list, 'state_dict': net_entity.state_dict(),
                          'num_test_images':num_test_images}
            checkpoint.update(storage.get_net_information(net_entity, dataset_name, net_name))
            torch.save(checkpoint,
                       os.path.join(checkpoint_path, 'dead_neural/round %d.tar' % round)
                       )
        else :
            print('hahaha')
            return

        # todo: when round > round_for_train, use the filter feature extractor

        '''filter pruning'''
        for i in conv_list:
            # TODO: to be reconsidered
            # # ensure the number of filters pruned will not be too large for one time
            # if type(max_filters_pruned_for_one_time) is list:
            #     num_filters_to_prune_max = filter_num[i] * max_filters_pruned_for_one_time[i]
            # else:
            #     num_filters_to_prune_max = filter_num[i] * max_filters_pruned_for_one_time
            # if num_filters_to_prune_max < len(dead_filter_index[i]):
            #     dead_filter_index[i] = dead_filter_index[i][:int(num_filters_to_prune_max)]
            # ensure the lower bound of filter number
            if filter_num[i] - len(dead_filter_index[i]) < filter_num_lower_bound[i]:
                raise Exception('i think something is wrong')
                dead_filter_index[i] = dead_filter_index[i][:filter_num[i] - filter_num_lower_bound[i]]

            print('layer {}: has {} filters, prunes {} filters, remains {} filters.'.
                  format(i, filter_num[i], len(dead_filter_index[i]),filter_num[i]-len(dead_filter_index[i])))
            net_entity.mask_filters(i,dead_filter_index[i])

        flop_after_prune = measure_model(net_entity.prune(), dataset_name)
        net_compressed= (flop_after_prune != flop_before_prune)
        if net_compressed is False:
            round -= 1
            print('{} round {} did not prune any filters. Restart.'.format(datetime.now(), round + 1))
            prune_rate += 0.02
            continue


        if tar_acc_gradual_decent is True:  # decrease the target_accuracy
            flop_reduced = flop_original_net - flop_after_prune
            target_accuracy = original_accuracy - acc_drop_tolerance * (flop_reduced / flop_drop_expected)
            print('{} current target accuracy:{}'.format(datetime.now(), target_accuracy))

        success = False
        training_round=0
        while not success:
            old_net = deepcopy(net)
            training_round+=1
            success = train.train(net=net,
                                  net_name=net_name,
                                  exp_name=exp_name,
                                  num_epochs=num_epoch,
                                  target_accuracy=target_accuracy,
                                  learning_rate=learning_rate,
                                  load_net=False,
                                  evaluate_step=evaluate_step,
                                  dataset_name=dataset_name,
                                  optimizer=optimizer,
                                  batch_size=batch_size,
                                  learning_rate_decay=learning_rate_decay,
                                  learning_rate_decay_factor=learning_rate_decay_factor,
                                  weight_decay=weight_decay,
                                  learning_rate_decay_epoch=learning_rate_decay_epoch,
                                  test_net=True,
                                  top_acc=top_acc,
                                  no_grad=['mask']
                                  )
            if success:
                prune_rate+= delta_prune_rate
                round += 1
                flop_before_prune=flop_after_prune
                net_entity.reset_mask()
                if num_epoch_after_finetune ==0:
                    break
                train.train(net=net,
                            net_name=net_name,
                            exp_name=exp_name,
                            num_epochs=num_epoch_after_finetune,
                            learning_rate=learning_rate,
                            load_net=False,
                            evaluate_step=evaluate_step,
                            dataset_name=dataset_name,
                            optimizer=optimizer,
                            batch_size=batch_size,
                            test_net=True,
                            top_acc=top_acc,
                            no_grad=['mask']
                            )
            else:
                net = old_net
                if isinstance(net, nn.DataParallel):
                    net_entity = net.module
                else:
                    net_entity = net
                net_entity.reset_mask()
                if max_training_round == training_round:
                    prune_rate-=delta_prune_rate/2
                    if prune_rate<=0:
                        print('{} failed to prune the net, pruning stop.'.format(datetime.now()))
                        return
                    break
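A minimal, hypothetical invocation of the function above, assuming the net is wrapped with the NetWithMask class from Example #4 (the experiment name, target accuracy, prune rate and FLOP target below are placeholders, not values from the repository):

if __name__ == "__main__":
    net = NetWithMask(dataset_name='cifar10', net_name='resnet56')  # masked wrapper as in Example #4
    dynamic_prune_inactive_neural_with_feature_extractor(
        net=net,
        net_name='resnet56',
        exp_name='resnet56_cifar10_dynamic_prune_demo',  # hypothetical experiment name
        target_accuracy=0.93,                            # placeholder
        initial_prune_rate=0.1,                          # placeholder
        delta_prune_rate=0.02,
        round_for_train=2,
        tar_acc_gradual_decent=True,
        flop_expected=0.5,                               # <=1 is interpreted as a ratio of the original FLOPs
        dataset_name='cifar10',
        batch_size=128,
        num_epoch=20,
    )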
Example #4
                                         groups=conv.groups,
                                         bias=(conv.bias is not None))
        self.weight = conv.weight
        if self.bias is not None:
            self.bias = conv.bias
        self.mask = mask
        self.front_mask = front_mask

    def forward(self, input):
        #mask the pruned filter and channel
        masked_weight = self.weight * self.mask.detach().reshape(
            (-1, 1, 1, 1)) * self.front_mask.detach().reshape((1, -1, 1, 1))
        masked_bias = self.bias * self.mask.detach() if self.bias is not None else None
        out = F.conv2d(input, masked_weight, masked_bias, self.stride,
                       self.padding, self.dilation, self.groups)

        return out


if __name__ == "__main__":
    net = NetWithMask(dataset_name='imagenet', net_name='vgg16_bn')
    #
    # tmp=net.prune()
    # net.mask_filters(layer_index=0,filter_index=[0,2,5])

    from framework import evaluate, data_loader
    dl = data_loader.create_validation_loader(batch_size=512,
                                              num_workers=2,
                                              dataset_name='cifar10')
    evaluate.evaluate_net(net, dl, save_net=False)
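As a quick, standalone sanity check of the masking scheme used in forward() above (plain torch, independent of the repository): zeroing one entry of the filter mask should zero the corresponding output channel.

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=True)
mask = torch.ones(4)
mask[2] = 0                               # "prune" filter 2
front_mask = torch.ones(3)                # keep all input channels
x = torch.randn(1, 3, 8, 8)
masked_weight = conv.weight * mask.reshape(-1, 1, 1, 1) * front_mask.reshape(1, -1, 1, 1)
masked_bias = conv.bias * mask
out = F.conv2d(x, masked_weight, masked_bias, conv.stride, conv.padding, conv.dilation, conv.groups)
print(out[0, 2].abs().max())              # ~0: the masked filter produces an all-zero channel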
Example #5
def plot_dead_neuron_filter_number(
    neural_dead_times=8000,
    dataset_name='cifar10',
):
    fontsize = 17
    label_fontsize = 24
    tick_fontsize = 20

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(
        '../data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    vgg16 = storage.restore_net(checkpoint)
    vgg16.load_state_dict(checkpoint['state_dict'])

    checkpoint = torch.load(
        '../data/baseline/resnet56_cifar10,accuracy=0.94230.tar')
    resnet56 = resnet_cifar.resnet56().to(device)
    resnet56.load_state_dict(checkpoint['state_dict'])

    vgg16_imagenet = vgg.vgg16_bn(pretrained=True).to(device)
    checkpoint = torch.load(
        '/home/victorfang/Desktop/vgg16_bn_imagenet_deadReLU.tar')
    relu_list_imagenet = checkpoint['relu_list']
    neural_list_imagenet = checkpoint['neural_list']

    loader = data_loader.create_validation_loader(batch_size=100,
                                                  num_workers=1,
                                                  dataset_name=dataset_name)
    # loader=data_loader.create_validation_loader(batch_size=1000,num_workers=8,dataset_name='cifar10_trainset')

    relu_list_vgg, neural_list_vgg = evaluate.check_ReLU_alive(
        net=vgg16,
        neural_dead_times=neural_dead_times,
        data_loader=loader,
        max_data_to_test=10000)
    relu_list_resnet, neural_list_resnet = evaluate.check_ReLU_alive(
        net=resnet56,
        neural_dead_times=neural_dead_times,
        data_loader=loader,
        max_data_to_test=10000)

    def get_statistics(net,
                       relu_list,
                       neural_list,
                       neural_dead_times,
                       sample_num=10000):
        num_conv = 0  # num of conv layers in the net
        for mod in net.modules():
            if isinstance(mod, torch.nn.modules.conv.Conv2d):
                num_conv += 1

        neural_dead_list = []  # list of per-neuron dead counts
        filter_dead_list = []  # list of per-filter dead ratios
        FIRE = []
        for i in range(num_conv):
            for relu_key in list(neural_list.keys()):
                if relu_list[
                        i] is relu_key:  # find the neural_list_statistics in layer i+1
                    dead_times = copy.deepcopy(neural_list[relu_key])
                    neural_dead_list += copy.deepcopy(
                        dead_times).flatten().tolist()

                    neural_num = dead_times.shape[1] * dead_times.shape[
                        2]  # neural num for one filter

                    # compute sum(dead_times)/(batch_size*neural_num) as label for each filter
                    dead_times = np.sum(dead_times, axis=(1, 2))
                    FIRE += (dead_times / (neural_num * sample_num)).tolist()

                    # # judge dead filter by neural_dead_times and dead_filter_ratio
                    # dead_times[dead_times < neural_dead_times] = 0
                    # dead_times[dead_times >= neural_dead_times] = 1
                    # dead_times = np.sum(dead_times, axis=(1, 2))  # count the number of dead neural for one filter
                    # dead_times = dead_times / neural_num
                    # filter_dead_list+=dead_times.tolist()
                    break
        active_ratio = 1 - np.array(FIRE)
        active_filter_list = 1 - np.array(filter_dead_list)
        neural_activated_list = (sample_num -
                                 np.array(neural_dead_list)) / sample_num

        return neural_activated_list, active_ratio,  #active_filter_list

    nal_vgg, afl_vgg = get_statistics(vgg16,
                                      relu_list_vgg,
                                      neural_list_vgg,
                                      neural_dead_times=neural_dead_times)
    nal_resnet, afl_resnet = get_statistics(
        resnet56,
        relu_list_resnet,
        neural_list_resnet,
        neural_dead_times=neural_dead_times)
    nal_imagenet, afl_imagenet = get_statistics(vgg16_imagenet,
                                                relu_list_imagenet,
                                                neural_list_imagenet,
                                                sample_num=50000,
                                                neural_dead_times=40000)

    # #cdf_of_dead_neurons
    # plt.figure()
    # plt.hist([nal_vgg,nal_resnet,nal_imagenet],cumulative=True,histtype='step',bins=1000,density=True,)#linewidth=5.0) # cumulative=False gives the pdf, True gives the cdf
    # plt.hist(neural_activated_list,cumulative=True,histtype='bar',bins=20,density=True,rwidth=0.6) # cumulative=False gives the pdf, True gives the cdf
    # plt.xlabel('Activation Ratio',fontsize = fontsize)
    # plt.ylabel('Ratio of Neurons',fontsize = fontsize)
    # plt.legend(['VGG-16 on CIFAR-10','ResNet-56 on CIFAR-10','VGG-16 on ImageNet'],loc='upper left',fontsize = fontsize)
    # # plt.savefig('0cdf_of_dead_neurons.jpg')
    # plt.savefig('cdf_of_dead_neurons.eps',format='eps')
    # plt.show()
    #
    # #cdf_of_inactive_filter
    # plt.figure()
    # plt.hist([afl_vgg,afl_resnet,afl_imagenet],cumulative=True,histtype='step',bins=1000,density=True,)#linewidth=5.0) # cumulative=False gives the pdf, True gives the cdf
    # plt.hist(neural_activated_list,cumulative=True,histtype='bar',bins=20,density=True,rwidth=0.6) # cumulative=False gives the pdf, True gives the cdf
    # plt.xlabel('Activation Ratio',fontsize = fontsize)
    # plt.ylabel('Ratio of Filters',fontsize = fontsize)
    # plt.legend(['VGG-16 on CIFAR-10','ResNet-56 on CIFAR-10','VGG-16 on ImageNet'],loc='upper left',fontsize = fontsize)
    # # plt.savefig('0cdf_of_inactive_filter.jpg')
    # plt.savefig('cdf_of_inactive_filter.eps',format='eps')
    # plt.show()

    #pdf_of_dead_neurons
    plt.figure()
    hist_list = []
    for nal in [nal_vgg, nal_resnet, nal_imagenet]:
        hist, bins = np.histogram(nal, bins=[0.1 * i for i in range(11)])
        hist_list.append(100 * hist / np.sum(hist))
    x_tick = np.array([5, 15, 25, 35, 45, 55, 65, 75, 85, 95])
    plt.figure()
    plt.bar(x_tick - 2,
            hist_list[0],
            color='coral',
            edgecolor='black',
            label='VGG-16 on CIFAR-10',
            align='center',
            width=2)
    plt.bar(x_tick,
            hist_list[1],
            color='cyan',
            edgecolor='black',
            label='ResNet-56 on CIFAR-10',
            align='center',
            width=2)
    plt.bar(x_tick + 2,
            hist_list[2],
            color='mediumslateblue',
            edgecolor='black',
            label='VGG-16 on ImageNet',
            align='center',
            width=2)
    plt.xticks(x_tick, x_tick, size=tick_fontsize)
    plt.yticks(size=tick_fontsize)
    plt.xlabel('Activation Ratio (%)', fontsize=label_fontsize)
    plt.ylabel('% of Neurons', fontsize=label_fontsize)
    plt.legend(loc='upper right', fontsize=fontsize)
    # plt.savefig('0pdf_of_dead_neurons.jpg')
    plt.savefig('pdf_of_dead_neurons.eps', format='eps', bbox_inches='tight')
    plt.show()

    #pdf_of_inactive_filter
    plt.figure()
    hist_list = []
    for active_ratio in [afl_vgg, afl_resnet, afl_imagenet]:
        hist, bins = np.histogram(active_ratio,
                                  bins=[0.1 * i for i in range(11)])
        hist_list.append(100 * hist / np.sum(hist))
    x_tick = np.array([5, 15, 25, 35, 45, 55, 65, 75, 85, 95])
    plt.figure()
    plt.bar(x_tick - 2,
            hist_list[0],
            color='coral',
            edgecolor='black',
            label='VGG-16 on CIFAR-10',
            align='center',
            width=2)
    plt.bar(x_tick,
            hist_list[1],
            color='cyan',
            edgecolor='black',
            label='ResNet-56 on CIFAR-10',
            align='center',
            width=2)
    plt.bar(x_tick + 2,
            hist_list[2],
            color='mediumslateblue',
            edgecolor='black',
            label='VGG-16 on ImageNet',
            align='center',
            width=2)
    plt.xticks(x_tick, x_tick, size=tick_fontsize)
    plt.yticks(size=tick_fontsize)
    plt.xlabel('Activation Ratio (%)', fontsize=label_fontsize)
    plt.ylabel('% of Filters', fontsize=label_fontsize)
    plt.legend(loc='upper right', fontsize=fontsize)
    # plt.savefig('0pdf_of_inactive_filter.jpg')
    plt.savefig('pdf_of_inactive_filter.eps',
                format='eps',
                bbox_inches='tight')
    plt.show()

    print()
Example #6
def speed_up_pruned_net():
    fontsize = 15

    checkpoint = torch.load(
        '../data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    net_original = storage.restore_net(checkpoint)

    checkpoint = torch.load(
        '../data/baseline/resnet56_cifar10,accuracy=0.93280.tar')
    net_original = resnet_cifar.resnet56()
    net_original.load_state_dict(checkpoint['state_dict'])

    checkpoint = torch.load(
        '../data/model_saved/vgg16bn_cifar10_realdata_regressor6_大幅度/checkpoint/flop=39915982,accuracy=0.93200.tar'
    )

    checkpoint = torch.load(
        '../data/model_saved/resnet56_cifar10_regressor_prunedBaseline2/checkpoint/flop=36145802,accuracy=0.92110.tar'
    )
    net_pruned = storage.restore_net(checkpoint)
    net_pruned.load_state_dict(checkpoint['state_dict'])

    # batch_size=[256,512,1024]
    num_workers = [i for i in range(4, 5)]
    batch_size = [300, 600, 1000, 1600]

    device_list = [torch.device('cuda')]  #
    # device_list=[torch.device('cpu')]
    for num_worker in num_workers:
        time_original = []
        time_pruned = []
        for d in device_list:
            for bs in batch_size:
                net_original.to(d)
                net_pruned.to(d)

                dl = data_loader.create_validation_loader(
                    batch_size=bs,
                    num_workers=num_worker,
                    dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_original,
                                      data_loader=dl,
                                      save_net=False,
                                      device=d)
                end_time = time.time()
                time_original.append(end_time - start_time)
                del dl

                dl = data_loader.create_validation_loader(
                    batch_size=bs,
                    num_workers=num_worker,
                    dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_pruned,
                                      data_loader=dl,
                                      save_net=False,
                                      device=d)
                end_time = time.time()
                time_pruned.append(end_time - start_time)
                del dl

        print('time before pruned:', time_original)
        print('time after pruned:', time_pruned)
        acceleration = np.array(time_original) / np.array(time_pruned)
        baseline = np.ones(shape=len(batch_size))
        x_tick = range(len(baseline))

        plt.figure()
        plt.bar(x_tick, acceleration, color='blue', hatch='//')  #,label='GPU')
        # plt.bar(x_tick[len(batch_size):], acceleration[len(batch_size):], color='grey', hatch='\\', label='CPU')
        # plt.bar(x_tick,baseline,color='red',hatch='*',label='Baseline')
        plt.xticks(x_tick, batch_size, fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        for x, y in enumerate(list(acceleration)):
            plt.text(x, y + 0.1, '%.2f x' % y, ha='center', fontsize=fontsize)
        plt.ylim([0, np.max(acceleration) + 0.5])
        plt.xlabel('Batch-Size', fontsize=fontsize)
        plt.ylabel('Speed-Up', fontsize=fontsize)
        # plt.legend(loc='upper left')
        plt.savefig('resnet_gpu_speed_up.eps', format='eps')
        # plt.savefig(str(num_worker)+'speed_up.jpg')
        plt.show()
        print()
Example #7
def find_useless_filters_data_version(
    net,
    batch_size,
    percent_of_inactive_filter,
    dataset_name='cifar10',
    use_random_data=False,
    module_list=None,
    neural_list=None,
    dead_or_inactive='inactive',
    neural_dead_times=None,
    filter_FIRE=None,
    # max_data_to_test=10000,
    num_filters_to_prune_at_most=None,
    max_data_to_test=10000,
):
    '''
    use the validation set or randomly generated data to find useless filters in the net
    :param net:
    :param batch_size:
    :param dataset_name:
    :param use_random_data:
    :param module_list:
    :param neural_list:
    :param dead_or_inactive: 'dead' or 'inactive'
    params for dead filters:
    :param neural_dead_times:
    :param filter_FIRE:
    :param percent_of_inactive_filter:
    :param max_data_to_test: use at most max_data_to_test images to calculate the inactive rate
    :param num_filters_to_prune_at_most: list containing the maximum number of filters that may be pruned in each layer
    :return:
    '''
    if dead_or_inactive == 'dead':
        if neural_dead_times is None or filter_FIRE is None:
            print(
                'neural_dead_times and filter_FIRE are required to find dead filters.'
            )
            raise AttributeError
    elif dead_or_inactive == 'inactive':
        if percent_of_inactive_filter is None:
            print(
                'percent_of_inactive_filter is required to find inactive filters.')
            raise AttributeError
    else:
        print('unknown type of dead_or_inactive')
        raise AttributeError
    if module_list is None or neural_list is None:
        #calculate dead neural
        if use_random_data is True:
            random_data = generate_random_data.random_normal(
                num=batch_size, dataset_name=dataset_name)
            num_test_images = batch_size
            print('{} generate random data.'.format(datetime.now()))
            # module_list, neural_list = check_conv_alive_layerwise(net=net,neural_dead_times=batch_size,batch_size=batch_size)
            module_list, neural_list = check_ReLU_alive(
                net=net, neural_dead_times=batch_size, data=random_data)
            del random_data
        else:
            if dataset_name == 'imagenet':
                train_set_size = conf.imagenet['train_set_size']
            elif dataset_name == 'cifar10':
                train_set_size = conf.cifar10['train_set_size']
            elif dataset_name == 'cifar100':
                train_set_size = conf.cifar100['train_set_size']
            elif dataset_name == 'tiny_imagenet':
                train_set_size = conf.tiny_imagenet['train_set_size']
            train_loader = data_loader.create_validation_loader(
                batch_size=batch_size,
                num_workers=8,
                dataset_name=dataset_name + '_trainset',
                shuffle=True,
            )

            num_test_images = min(
                train_set_size,
                math.ceil(max_data_to_test / batch_size) * batch_size)
            if neural_dead_times is None and dead_or_inactive == 'inactive':
                neural_dead_times = 0.8 * num_test_images

            if isinstance(net, torch.nn.DataParallel):
                net_test = copy.deepcopy(net._modules['module'])  #use one gpu
            else:
                net_test = copy.deepcopy(net)
            module_list, neural_list = check_ReLU_alive(
                net=net_test,
                data_loader=train_loader,
                neural_dead_times=neural_dead_times,
                max_data_to_test=max_data_to_test)
            del net_test
            del train_loader
    num_conv = 0  # num of conv layers in the net
    filter_num = []
    for name, mod in net.named_modules():
        if isinstance(mod, torch.nn.Conv2d) and 'downsample' not in name:
            num_conv += 1
            filter_num.append(mod.out_channels)

    useless_filter_index = [[] for i in range(num_conv)]
    if dead_or_inactive == 'inactive':
        # filter_index=[]                                 #index of the filter in its layer
        # filter_layer=[]                                 #which layer the filter is in
        FIRE = []
    for i in range(len(module_list)):  # iterate over the relu layers that follow each conv layer
        for module_key in list(neural_list.keys()):
            if module_list[
                    i] is module_key:  # find the neural_list_statistics in layer i+1
                dead_times = copy.deepcopy(neural_list[module_key])
                neural_num = dead_times.shape[1] * dead_times.shape[
                    2]  # neural num for one filter

                if dead_or_inactive == 'dead':
                    print(
                        'warning!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! may be wrong!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
                    )
                    # judge dead filter by neural_dead_times and dead_filter_ratio
                    dead_times[dead_times < neural_dead_times] = 0
                    dead_times[dead_times >= neural_dead_times] = 1
                    dead_times = np.sum(dead_times, axis=(
                        1,
                        2))  # count the number of dead neural for one filter

                    df_num = np.where(
                        dead_times >= neural_num *
                        filter_FIRE)[0].shape[0]  #number of dead filters
                    df_index = np.argsort(-dead_times)[:df_num].tolist(
                    )  #dead filters' indices. sorted by the times that they died.
                    useless_filter_index[i] = df_index
                elif dead_or_inactive == 'inactive':
                    # compute sum(dead_times)/(batch_size*neural_num) as label for each filter
                    dead_times = np.sum(dead_times, axis=(1, 2))
                    # if use_random_data is True:
                    #     FIRE += (dead_times / (neural_num * batch_size)).tolist()
                    # else:
                    FIRE += (dead_times /
                             (neural_num * num_test_images)).tolist()
                    # filter_layer += [i for j in range(dead_times.shape[0])]
                    # filter_index+=[j for j in range(dead_times.shape[0])]

    if dead_or_inactive == 'dead':
        return useless_filter_index, module_list, neural_list
    elif dead_or_inactive == 'inactive':
        useless_filter_index = sort_inactive_filters(
            net, percent_of_inactive_filter, FIRE,
            num_filters_to_prune_at_most)
        # cutoff_rank_increase=-1
        # delta=0
        # while cutoff_rank_increase!=delta:
        #     cutoff_rank_increase = delta
        #     delta = 0
        #     cutoff_rank=int(percent_of_inactive_filter*len(FIRE))+cutoff_rank_increase
        #     inactive_rank=np.argsort(-np.array(FIRE))[:cutoff_rank]                #arg for top (percent_of_inactive_filter)*100% of inactive filters
        #     inactive_filter_index=np.array(filter_index)[inactive_rank]                     #index of top (percent_of_inactive_filter)*100% inactive filters
        #     inactive_filter_layer=np.array(filter_layer)[inactive_rank]
        #     for i in range(num_conv):
        #         index = inactive_filter_index[np.where(inactive_filter_layer == i)]     #index of inactive filters in layer i
        #         if num_filters_to_prune_at_most is not None:
        #             if len(index)>num_filters_to_prune_at_most[i]:
        #                 delta+=len(index)-num_filters_to_prune_at_most[i]           #number of inactive filters excluded because of the restriction
        #                 index=index[:num_filters_to_prune_at_most[i]]
        #         useless_filter_index[i]=index

        return useless_filter_index, module_list, neural_list, FIRE
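sort_inactive_filters is not shown in this example; a sketch of the selection it presumably performs, reconstructed from the commented-out logic above (rank all filters by FIRE, keep the top percent_of_inactive_filter fraction, and respect the optional per-layer cap), could be:

import numpy as np
import torch

def sort_inactive_filters_sketch(net, percent_of_inactive_filter, FIRE,
                                 num_filters_to_prune_at_most=None):
    # map every conv filter to (layer index, index within layer), as in the code above
    filter_layer, filter_index = [], []
    layer = 0
    for name, mod in net.named_modules():
        if isinstance(mod, torch.nn.Conv2d) and 'downsample' not in name:
            filter_layer += [layer] * mod.out_channels
            filter_index += list(range(mod.out_channels))
            layer += 1
    useless_filter_index = [[] for _ in range(layer)]
    cutoff_rank = int(percent_of_inactive_filter * len(FIRE))
    inactive_rank = np.argsort(-np.array(FIRE))[:cutoff_rank]  # most inactive first
    for r in inactive_rank:
        l, idx = filter_layer[r], filter_index[r]
        if (num_filters_to_prune_at_most is None
                or len(useless_filter_index[l]) < num_filters_to_prune_at_most[l]):
            useless_filter_index[l].append(idx)
    return useless_filter_index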
Example #8
# def conversion(dataset_name,net_name,checkpoint_path='',checkpoint=None):

# convert an older resnet checkpoint trained on cifar
#     checkpoint=torch.load(checkpoint_path)
#     # net=checkpoint.pop('net')
#     net=resnet_cifar.resnet32()
#     checkpoint['state_dict']['fc.weight']=checkpoint['state_dict'].pop('linear.weight')
#     checkpoint['state_dict']['fc.bias']=checkpoint['state_dict'].pop('linear.bias')
#
#     net.load_state_dict(checkpoint['state_dict'])
#     checkpoint.update(get_net_information(net,dataset_name,net_name))
#     torch.save(checkpoint,checkpoint_path)

if __name__ == "__main__":
    # conversion(checkpoint_path='../data/baseline/resnet32_cifar10,accuracy=0.92380.tar',net_name='resnet56',dataset_name='cifar10')

    checkpoint = torch.load(
        '/home/victorfang/model_pytorch/data/baseline/resnet56_cifar10,accuracy=0.94230.tar'
    )

    net = resnet_cifar.resnet56()
    # net.load_state_dict(checkpoint['state_dict'])
    c = get_net_information(net=net,
                            dataset_name='cifar10',
                            net_name='resnet56')
    net = restore_net(checkpoint, True)
    from framework import evaluate, data_loader
    evaluate.evaluate_net(
        net, data_loader.create_validation_loader(512, 4, 'cifar10'), False)
    # c=get_net_information(net=net,dataset_name=dataset_name,net_name='resnet50')
Example #9
#     c_new = {'highest_accuracy': c_original['highest_accuracy'],
#              'state_dict': new_state_dict}
#
#     torch.save(c_new, path)

if __name__ == "__main__":
    import torch
    from network import storage
    from framework import evaluate, data_loader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(
        '/home/victorfang/model_pytorch/data/baseline/resnet56_cifar100_0.71580.tar'
    )
    # c_sample=torch.load('/home/victorfang/model_pytorch/data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    net = resnet56(num_classes=100)
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)

    # checkpoint.update(storage.get_net_information(net=net,dataset_name='tiny_imagenet',net_name='resnet18'))
    # checkpoint.pop('net')
    # torch.save(checkpoint,'/home/victorfang/model_pytorch/data/baseline/resnet18_tinyimagenet_v2_0.72990.tar')
    # net=storage.restore_net(checkpoint=checkpoint,pretrained=True)
    # net=nn.DataParallel(net)
    evaluate.evaluate_net(net=net,
                          data_loader=data_loader.create_validation_loader(
                              512, 8, 'cifar100'),
                          save_net=False,
                          dataset_name='cifar100')

    print()
Example #10
def train(
        net,
        net_name,
        exp_name='',
        dataset_name='imagenet',
        train_loader=None,
        validation_loader=None,
        learning_rate=conf.learning_rate,
        num_epochs=conf.num_epochs,
        batch_size=conf.batch_size,
        evaluate_step=conf.evaluate_step,
        load_net=True,
        test_net=False,
        root_path=conf.root_path,
        checkpoint_path=None,
        momentum=conf.momentum,
        num_workers=conf.num_workers,
        learning_rate_decay=False,
        learning_rate_decay_factor=conf.learning_rate_decay_factor,
        learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
        weight_decay=conf.weight_decay,
        target_accuracy=1.0,
        optimizer=optim.SGD,
        top_acc=1,
        criterion=nn.CrossEntropyLoss(),  # loss function defaults to cross-entropy, commonly used for multi-class classification
        no_grad=[],
        scheduler_name='MultiStepLR',
        eta_min=0,
        #todo:tmp!!!
        data_parallel=False):
    '''

    :param net: net to be trained
    :param net_name: name of the net
    :param exp_name: name of the experiment
    :param dataset_name: name of the dataset
    :param train_loader: data_loader for training. If not provided, a data_loader will be created based on dataset_name
    :param validation_loader: data_loader for validation. If not provided, a data_loader will be created based on dataset_name
    :param learning_rate: initial learning rate
    :param learning_rate_decay: boolean, if true, the learning rate will decay based on the params provided.
    :param learning_rate_decay_factor: float. learning_rate*=learning_rate_decay_factor, every time it decay.
    :param learning_rate_decay_epoch: list[int], the specific epoch that the learning rate will decay.
    :param num_epochs: max number of epochs for training
    :param batch_size:
    :param evaluate_step: how often the net will be tested on the validation set; at least one test per epoch is guaranteed
    :param load_net: boolean, whether to load the net from a previous checkpoint; the newest checkpoint will be selected.
    :param test_net: boolean, if true, the net will be tested before training.
    :param root_path:
    :param checkpoint_path:
    :param momentum:
    :param num_workers:
    :param weight_decay:
    :param target_accuracy: float, training will stop once the net reaches the target accuracy
    :param optimizer:
    :param top_acc: can be 1 or 5
    :param criterion: loss function
    :param no_grad: list containing names of the modules that do not need to be trained
    :param scheduler_name:
    :param eta_min: for CosineAnnealingLR
    :return:
    '''
    success = True  #if the trained net reaches target accuracy
    # gpu or not
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(), ' * ', end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    #prepare the data
    if dataset_name == 'imagenet':
        mean = conf.imagenet['mean']
        std = conf.imagenet['std']
        train_set_path = conf.imagenet['train_set_path']
        train_set_size = conf.imagenet['train_set_size']
        validation_set_path = conf.imagenet['validation_set_path']
        default_image_size = conf.imagenet['default_image_size']
    elif dataset_name == 'cifar10':
        train_set_size = conf.cifar10['train_set_size']
        mean = conf.cifar10['mean']
        std = conf.cifar10['std']
        train_set_path = conf.cifar10['dataset_path']
        validation_set_path = conf.cifar10['dataset_path']
        default_image_size = conf.cifar10['default_image_size']
    elif dataset_name == 'tiny_imagenet':
        train_set_size = conf.tiny_imagenet['train_set_size']
        mean = conf.tiny_imagenet['mean']
        std = conf.tiny_imagenet['std']
        train_set_path = conf.tiny_imagenet['train_set_path']
        validation_set_path = conf.tiny_imagenet['validation_set_path']
        default_image_size = conf.tiny_imagenet['default_image_size']
    elif dataset_name == 'cifar100':
        train_set_size = conf.cifar100['train_set_size']
        mean = conf.cifar100['mean']
        std = conf.cifar100['std']
        train_set_path = conf.cifar100['dataset_path']
        validation_set_path = conf.cifar100['dataset_path']
        default_image_size = conf.cifar100['default_image_size']
    if train_loader is None:
        train_loader = data_loader.create_train_loader(
            dataset_path=train_set_path,
            default_image_size=default_image_size,
            mean=mean,
            std=std,
            batch_size=batch_size,
            num_workers=num_workers,
            dataset_name=dataset_name)
    if validation_loader is None:
        validation_loader = data_loader.create_validation_loader(
            dataset_path=validation_set_path,
            default_image_size=default_image_size,
            mean=mean,
            std=std,
            batch_size=batch_size,
            num_workers=num_workers,
            dataset_name=dataset_name)

    if checkpoint_path is None:
        checkpoint_path = os.path.join(root_path, 'model_saved', exp_name,
                                       'checkpoint')
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    #get the latest checkpoint
    lists = os.listdir(checkpoint_path)
    file_new = checkpoint_path
    if len(lists) > 0:
        lists.sort(key=lambda fn: os.path.getmtime(checkpoint_path + "/" + fn))  # sort by modification time
        file_new = os.path.join(checkpoint_path, lists[-1])  # keep the newest checkpoint in file_new

    sample_num = 0
    if os.path.isfile(file_new):
        if load_net:
            checkpoint = torch.load(file_new)
            print('{} load net from previous checkpoint:{}'.format(
                datetime.now(), file_new))
            # net.load_state_dict(checkpoint['state_dict'])
            net = storage.restore_net(checkpoint,
                                      pretrained=True,
                                      data_parallel=data_parallel)
            sample_num = checkpoint['sample_num']

    if test_net:
        print('{} test the net'.format(
            datetime.now()))  #no previous checkpoint
        net_test = copy.deepcopy(net)
        accuracy = evaluate.evaluate_net(net_test,
                                         validation_loader,
                                         save_net=True,
                                         checkpoint_path=checkpoint_path,
                                         sample_num=sample_num,
                                         target_accuracy=target_accuracy,
                                         dataset_name=dataset_name,
                                         top_acc=top_acc,
                                         net_name=net_name,
                                         exp_name=exp_name)
        del net_test

        if accuracy >= target_accuracy:
            print('{} net reached target accuracy.'.format(datetime.now()))
            return success

    #ensure the net will be evaluated despite the inappropriate evaluate_step
    if evaluate_step > math.ceil(train_set_size / batch_size) - 1:
        evaluate_step = math.ceil(train_set_size / batch_size) - 1

    optimizer = prepare_optimizer(net, optimizer, no_grad, momentum,
                                  learning_rate, weight_decay)
    if learning_rate_decay:
        if scheduler_name == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=learning_rate_decay_epoch,
                gamma=learning_rate_decay_factor,
                last_epoch=ceil(sample_num / train_set_size))
        elif scheduler_name == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer,
                num_epochs,
                eta_min=eta_min,
                last_epoch=ceil(sample_num / train_set_size))
    print("{} Start training ".format(datetime.now()) + net_name + "...")
    for epoch in range(math.floor(sample_num / train_set_size), num_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        net.train()
        # one epoch for one loop
        for step, data in enumerate(train_loader, 0):
            if sample_num / train_set_size == epoch + 1:  #one epoch of training finished
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(
                    net_test,
                    validation_loader,
                    save_net=True,
                    checkpoint_path=checkpoint_path,
                    sample_num=sample_num,
                    target_accuracy=target_accuracy,
                    dataset_name=dataset_name,
                    top_acc=top_acc,
                    net_name=net_name,
                    exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(
                        datetime.now()))
                    return success
                break

            # prepare the batch
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            sample_num += int(images.shape[0])

            # if learning_rate_decay:
            #     exponential_decay_learning_rate(optimizer=optimizer,
            #                                     sample_num=sample_num,
            #                                     learning_rate_decay_factor=learning_rate_decay_factor,
            #                                     train_set_size=train_set_size,
            #                                     learning_rate_decay_epoch=learning_rate_decay_epoch,
            #                                     batch_size=batch_size)

            optimizer.zero_grad()
            # forward + backward
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 60 == 0:
                print('{} loss is {}'.format(datetime.now(), float(loss.data)))

            if step % evaluate_step == 0 and step != 0:
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(
                    net_test,
                    validation_loader,
                    save_net=True,
                    checkpoint_path=checkpoint_path,
                    sample_num=sample_num,
                    target_accuracy=target_accuracy,
                    dataset_name=dataset_name,
                    top_acc=top_acc,
                    net_name=net_name,
                    exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(
                        datetime.now()))
                    return success
                accuracy = float(accuracy)
                print('{} continue training'.format(datetime.now()))
        if learning_rate_decay:
            scheduler.step()
            print(optimizer.state_dict()['param_groups'][0]['lr'])

    print("{} Training finished. Saving net...".format(datetime.now()))
    net_test = copy.deepcopy(net)
    flop_num = measure_flops.measure_model(net=net_test,
                                           dataset_name=dataset_name,
                                           print_flop=False)
    accuracy = evaluate.evaluate_net(net_test,
                                     validation_loader,
                                     save_net=True,
                                     checkpoint_path=checkpoint_path,
                                     sample_num=sample_num,
                                     target_accuracy=target_accuracy,
                                     dataset_name=dataset_name,
                                     top_acc=top_acc,
                                     net_name=net_name,
                                     exp_name=exp_name)
    accuracy = float(accuracy)
    checkpoint = {
        'highest_accuracy': accuracy,
        'state_dict': net.state_dict(),
        'sample_num': sample_num,
        'flop_num': flop_num
    }
    checkpoint.update(storage.get_net_information(net, dataset_name, net_name))
    torch.save(
        checkpoint,
        '%s/flop=%d,accuracy=%.5f.tar' % (checkpoint_path, flop_num, accuracy))
    print("{} net saved at sample num = {}".format(datetime.now(), sample_num))
    return not success
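A minimal, hypothetical call to train(), assuming resnet_cifar is importable as in the earlier examples; the experiment name and hyper-parameters below are placeholders rather than values from the repository:

if __name__ == "__main__":
    checkpoint = torch.load('../data/baseline/resnet56_cifar10,accuracy=0.94230.tar')
    net = resnet_cifar.resnet56()
    net.load_state_dict(checkpoint['state_dict'])
    train(net=net,
          net_name='resnet56',
          exp_name='resnet56_cifar10_finetune_demo',  # hypothetical experiment name
          dataset_name='cifar10',
          num_epochs=20,                              # placeholder hyper-parameters
          batch_size=128,
          learning_rate=0.01,
          learning_rate_decay=True,
          learning_rate_decay_epoch=[10, 15],
          learning_rate_decay_factor=0.1,
          load_net=False,
          test_net=True,
          target_accuracy=0.94)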