示例#1
0
def compare_train():
    """Train a network built from a pruned ``cfg`` and keep its best checkpoint.

    Two modes, selected by ``args.resume``:

    * ``'none'``: read ``cfg`` from the ``args.ref_model`` checkpoint and build a
      fresh ``args.arch`` network from it (the reference weights are NOT
      loaded). Checkpoints go to ``<tmp_dir>/<args.arch>_compare.pth``.
    * anything else: restore network weights, optimizer state, best accuracy
      (``'lasted_best_prec1'``) and the saved ``'epoch'`` from the
      ``args.resume`` checkpoint and keep saving to ``<tmp_dir>/<args.resume>``.

    Each epoch trains on ``trainloader``, evaluates top-1/top-5 on
    ``testloader`` and overwrites the checkpoint whenever top-1 beats the best
    accuracy seen so far. Uses the module-level ``args``, ``tmp_dir``,
    ``device``, ``trainloader``, ``testloader``, ``get_model`` and ``utils``.
    """
    import contextlib

    end_epoch = 100
    resuming = args.resume != 'none'
    if resuming:
        model_state_pre_best = get_model(args.resume, device=device)
        best_acc = model_state_pre_best['lasted_best_prec1']
        save_path = os.path.join(tmp_dir, args.resume)
        # Continue after the epoch recorded in the checkpoint (written below
        # under 'epoch'); checkpoints without it restart from epoch 0.
        start_epoch = model_state_pre_best.get('epoch', 0)
    else:
        save_path = os.path.join(tmp_dir, args.arch + '_compare.pth')
        model_state_pre_best = get_model(args.ref_model, device=device)
        best_acc = 0  # best test accuracy
        start_epoch = 0

    cfg = model_state_pre_best['cfg']
    net_pruned = eval(args.arch)(args.num_class, cfg=cfg)
    if resuming:
        net_pruned.load_state_dict(model_state_pre_best['state_dict'])
    net_pruned.to(device)
    optimizer = optim.SGD(net_pruned.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    if resuming:
        optimizer.load_state_dict(model_state_pre_best['optimizer'])

    # torch.cuda.device() raises ValueError for a non-CUDA device, so only
    # select the CUDA device when actually training on one.
    on_cuda = torch.device(device).type == 'cuda'
    criterion = nn.CrossEntropyLoss()
    for epoch in range(start_epoch + 1, end_epoch):
        net_pruned.train()
        for inputs, targets in trainloader:
            with (torch.cuda.device(device) if on_cuda else contextlib.nullcontext()):
                inputs = inputs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                loss = criterion(net_pruned(inputs), targets)
                loss.backward()
                optimizer.step()

        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        net_pruned.eval()
        num_iterations = len(testloader)
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net_pruned(inputs)

                prec1, prec5 = utils.accuracy(outputs, targets, topk=(1, 5))
                top1.update(prec1[0], inputs.size(0))
                top5.update(prec5[0], inputs.size(0))

            print(
                'Epoch[{0}]({1}/{2}): '
                'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                    epoch, batch_idx, num_iterations, top1=top1, top5=top5))

        current_acc = top1.avg.item()
        if current_acc > best_acc:
            print('当前测试精度%f大于历史最好精度%f,保存模型%s' % (
                current_acc, best_acc, save_path))

            best_acc = current_acc

            model_state = {
                'state_dict': net_pruned.state_dict(),
                'lasted_best_prec1': round(best_acc, 2),
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'cfg': cfg,
            }
            torch.save(model_state, save_path)
示例#2
0
def prune_train() -> None:
    """Prune conv layers one at a time, from the last to the first, fine-tuning after each.

    ``conv_list`` is the network's conv-layer names reversed, so the deepest
    layer is visited first. For every layer not already in ``pruned_layers``:

    1. reload the checkpoint ``args.resume`` and rebuild the network and SGD
       optimizer from it;
    2. call ``search_by_conv_idx`` to obtain this layer's entry for the prune
       dict (it also returns a mean top-1 value that is logged and saved);
    3. build the pruned network with ``get_net_by_prune_dict``;
    4. fine-tune for epochs ``start_epoch + 1`` .. ``start_epoch + 39``. Whenever
       test top-1 beats the best seen for this layer, save a checkpoint to
       ``<tmp_dir>/<args.resume>`` plus a copy named ``<acc>_<layer>.pth``.

    Relies on module-level ``args``, ``tmp_dir``, ``device``, ``trainloader``,
    ``testloader``, ``utils``, ``get_model`` and the pruning helpers.
    """

    # NOTE(review): printed unconditionally; nothing here checks that the
    # checkpoint actually exists.
    print('checkpoint %s exists,train from checkpoint' % args.resume)
    save_path = os.path.join(tmp_dir, args.resume)
    model_state = get_model(args.resume, device=device)

    # conv_list_asc keeps the network's conv order; conv_list is reversed so
    # pruning starts from the last conv layer.
    conv_list = get_conv_name_list(model_state)
    conv_list_asc = conv_list.copy()
    conv_list.reverse()

    # Every layer is fine-tuned over the same epoch range, which starts after
    # the epoch stored in the initial checkpoint.
    try:
        start_epoch = model_state['epoch']
    except KeyError:
        start_epoch = 0
    end_epoch = start_epoch + 40

    cudnn.benchmark = True
    print('\nEpoch: %d' % start_epoch)

    criterion = nn.CrossEntropyLoss()
    for conv_layername in conv_list:
        # The checkpoint is re-read for every layer.
        # NOTE(review): presumably this picks up what the previous layer saved
        # to save_path. Confirm that get_model resolves args.resume inside tmp_dir.
        model_state_pre_best = get_model(args.resume, device=device)

        # Layers already pruned come from the in-memory model_state (the initial
        # checkpoint, or the last state saved below). When the key exists, this
        # is the same dict object that is mutated further down.
        try:
            pruned_layers = model_state['pruned_layer']
        except KeyError:
            pruned_layers = dict()
        # A baseline checkpoint only has 'best_prec1'; record it as the baseline.
        try:
            baseline_best_prec1 = model_state_pre_best['baseline_best_prec1']
        except KeyError:
            baseline_best_prec1 = model_state_pre_best['best_prec1']
            model_state_pre_best['baseline_best_prec1'] = baseline_best_prec1

        # NOTE(review): previous_lay_qualified_top1_mean is read here but never used.
        try:
            previous_lay_qualified_top1_mean = model_state_pre_best[
                'previous_lay_qualified_top1_mean']
        except KeyError:
            model_state_pre_best[
                'previous_lay_qualified_top1_mean'] = baseline_best_prec1

        if not 'lasted_best_prec1' in model_state_pre_best.keys():
            model_state_pre_best['lasted_best_prec1'] = model_state_pre_best[
                'baseline_best_prec1']

        # Reset the current layer's best accuracy to 0 when starting to prune it.
        lasted_best_prec1 = 0

        # Checkpoints without 'cfg' build the network with cfg=None.
        try:
            cfg = model_state_pre_best['cfg']
        except KeyError:
            cfg = None
        net_pre_best = eval(args.arch)(args.num_class, cfg=cfg)
        net_pre_best.load_state_dict(model_state_pre_best['state_dict'])
        # NOTE(review): this optimizer is bound to net_pre_best's parameters, but
        # the training loop below runs net_current_pruned. If
        # get_net_by_prune_dict returns a new module, optimizer.step() never
        # updates the pruned network. Confirm.
        optimizer = optim.SGD(net_pre_best.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
        optimizer.load_state_dict(model_state_pre_best['optimizer'])
        # NOTE(review): lasted_best_prec1 was just reset, so this always prints 0.
        print("获取上轮剪枝最佳模型,剪枝清单为%s,测试后准确率为%f" %
              (pruned_layers, lasted_best_prec1))
        # Move optimizer state tensors to the GPU (hard-coded .cuda(), not `device`).
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        # If this layer has already been pruned, skip the logic below and move
        # on to the next layer. (The checkpoint, network and optimizer above are
        # still built in that case.)
        if conv_layername in pruned_layers.keys():
            continue
        # Index of the current conv layer in the overall (ascending) conv order.
        # Every layer starts at -1, which appears to mark "not pruned".
        conv_prune_dict_cleaned = {name: -1 for name in conv_list_asc}
        conv_idx = get_conv_idx_by_name(conv_layername, conv_list_asc)
        print('%s在所有卷积层总顺序派的索引号为%d' % (conv_layername, conv_idx))
        print("处理%s,剪枝测试中" % conv_layername)
        # conv_prune_dict_cleaned[conv_layername] = {i if i%10==0 else 0 for i in range(512)}
        # Only the current layer's entry is filled in; the search also returns
        # the mean top-1 stored as 'previous_lay_qualified_top1_mean'.
        conv_prune_dict_cleaned[
            conv_layername], current_lay_qualified_top1_mean = search_by_conv_idx(
                model_state_pre_best, net_pre_best.origin_cfg, conv_layername,
                conv_idx, len(conv_list), testloader, args)

        print("剪枝字典为%s" % str(conv_prune_dict_cleaned))
        net_current_pruned, pruned_cfg = get_net_by_prune_dict(
            net_pre_best, args, conv_prune_dict_cleaned)
        net_current_pruned = net_current_pruned.to(device)
        pruned_layers[conv_layername] = conv_prune_dict_cleaned[conv_layername]
        # NOTE(review): total_drop_list is never used below.
        total_drop_list = [
            list(v) if isinstance(v, list) else -1
            for k, v in conv_prune_dict_cleaned.items()
        ]

        # Per-layer drop list passed to the network's forward. It contains the
        # entries of model_state['pruned_layer'] in insertion (pruning) order,
        # i.e. deepest layer first. It is padded with -1 for layers not yet
        # pruned and then reversed into ascending conv order. This relies on
        # layers being pruned strictly from the last one backwards.
        # NOTE(review): on the first iteration, if the initial checkpoint had no
        # 'pruned_layer', pruned_layers is a separate dict. The layer just
        # pruned is then missing here and every entry is -1.
        total_drop_list_resnet34 = []

        try:
            for k, v in model_state['pruned_layer'].items():
                total_drop_list_resnet34.append(v)
        except KeyError:
            print('pruned_layer为空')

        for i in range(len(conv_list) - len(total_drop_list_resnet34)):
            total_drop_list_resnet34.append(-1)
        total_drop_list_resnet34.reverse()
        for epoch in range(start_epoch + 1, end_epoch):
            net_current_pruned.train()
            # train_loss / correct / total only feed the commented-out print below.
            train_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(trainloader):
                with torch.cuda.device(device):
                    inputs = inputs.to(device)
                    targets = targets.to(device)
                    optimizer.zero_grad()
                    # The pruned network's forward takes the drop list as a
                    # second argument.
                    outputs = net_current_pruned(inputs,
                                                 total_drop_list_resnet34)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

                    # print(batch_idx,len(trainloader),
                    #              ' Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    #              % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
            top1 = utils.AverageMeter()
            top5 = utils.AverageMeter()
            net_current_pruned.eval()
            num_iterations = len(testloader)
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(testloader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = net_current_pruned(inputs,
                                                 total_drop_list_resnet34)

                    prec1, prec5 = utils.accuracy(outputs,
                                                  targets,
                                                  topk=(1, 5))
                    top1.update(prec1[0], inputs.size(0))
                    top5.update(prec5[0], inputs.size(0))

                print('Epoch[{0}]({1}/{2}): '
                      'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                          epoch,
                          batch_idx,
                          num_iterations,
                          top1=top1,
                          top5=top5))

            # Save whenever this epoch beats the best accuracy for the current
            # layer. That best was reset to 0 above, so the first epoch with
            # nonzero accuracy always saves.
            if (top1.avg.item() > lasted_best_prec1):
                print(
                    '当前测试精度%f大于当前层剪枝周期的最号精度%f,baseline最好精度为%f,当前层剪枝精度平均%f,保存模型%s)'
                    % (top1.avg.item(), lasted_best_prec1, baseline_best_prec1,
                       current_lay_qualified_top1_mean, save_path))

                lasted_best_prec1 = top1.avg.item()

                # Rebinding model_state makes the next layer read pruned_layers
                # from this saved state.
                model_state = {
                    'state_dict': net_current_pruned.state_dict(),
                    'baseline_best_prec1': baseline_best_prec1,
                    'lasted_best_prec1': round(lasted_best_prec1, 2),
                    'epoch': epoch,
                    'previous_lay_qualified_top1_mean':
                    current_lay_qualified_top1_mean,
                    'optimizer': optimizer.state_dict(),
                    'pruned_layer': pruned_layers,
                    'cfg': pruned_cfg,
                }
                torch.save(model_state, save_path)
                # Also keep a snapshot named "<accuracy>_<layer>.pth".
                best_model_path = os.path.join(
                    tmp_dir,
                    str(lasted_best_prec1) + '_' + conv_layername + '.pth')
                torch.save(model_state, best_model_path)
示例#3
0
    m.layer_mask(cov_id + 1,
                 resume=args.resume_mask,
                 param_per_cov=param_per_cov_dic[args.arch],
                 arch=args.arch)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=lr_decay_step,
                                               gamma=0.1)

    if cov_id == 0:

        pruned_checkpoint = get_model(args.resume, device=device)
        from collections import OrderedDict

        new_state_dict = OrderedDict()
        if args.arch == 'resnet_50':
            tmp_ckpt = pruned_checkpoint
        else:
            tmp_ckpt = pruned_checkpoint['state_dict']

        if len(args.gpu) > 1:
            for k, v in tmp_ckpt.items():
                new_state_dict['module.' + k.replace('module.', '')] = v
        else:
            for k, v in tmp_ckpt.items():
                new_state_dict[k.replace('module.', '')] = v
示例#4
0
def train_baseline():
    """Train the unpruned ``args.arch`` baseline, optionally resuming a checkpoint.

    With ``args.resume != "none"``, the network weights, optimizer state, best
    accuracy (``'best_prec1'``) and epoch are restored from that checkpoint.
    Training then runs epochs ``start_epoch + 1`` .. ``start_epoch + 279``, and
    results are saved back to ``<baseline_dir>/<args.resume>``. Otherwise a new
    network trains epochs 1..99 and is saved to
    ``<baseline_dir>/<args.arch>_<args.dataset>.pth``.

    The checkpoint is overwritten whenever test top-1 accuracy improves. Uses
    the module-level ``args``, ``baseline_dir``, ``device``, ``trainloader``,
    ``testloader``, ``get_model`` and ``utils``.
    """
    import contextlib

    if args.resume != "none":
        print('checkpoint %s exists,train from checkpoint' % args.resume)
        save_path = os.path.join(baseline_dir, args.resume)
        model_state = get_model(args.resume, device=device)
        cfg = model_state['cfg']
        # NOTE(review): cfg is only re-saved below; the network is rebuilt with
        # the constructor's default cfg. Confirm this is intended.
        net = eval(args.arch)(args.num_class)
        current_model_best_acc = model_state['best_prec1']
        net.load_state_dict(model_state['state_dict'])
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
        optimizer.load_state_dict(model_state['optimizer'])
        # Move the restored optimizer buffers to the training device. A
        # hard-coded .cuda() broke CPU runs and non-default GPUs.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
        start_epoch = model_state.get('epoch', 0)
        end_epoch = start_epoch + 280
    else:
        save_path = os.path.join(baseline_dir,
                                 args.arch + '_' + args.dataset + '.pth')
        current_model_best_acc = 0
        net = eval(args.arch)(args.num_class)
        cfg = net.cfg
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
        start_epoch = 0
        end_epoch = 100

    net = net.to(device)
    cudnn.benchmark = True
    print('\nEpoch: %d' % start_epoch)

    # torch.cuda.device() raises ValueError for a non-CUDA device, so only
    # select it when training on a GPU.
    on_cuda = torch.device(device).type == 'cuda'
    criterion = nn.CrossEntropyLoss()
    for epoch in range(start_epoch + 1, end_epoch):
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            with (torch.cuda.device(device) if on_cuda else contextlib.nullcontext()):
                inputs = inputs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                print(
                    batch_idx, len(trainloader),
                    ' Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                    (train_loss /
                     (batch_idx + 1), 100. * correct / total, correct, total))
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        net.eval()
        num_iterations = len(testloader)
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                prec1, prec5 = utils.accuracy(outputs, targets, topk=(1, 5))
                top1.update(prec1[0], inputs.size(0))
                top5.update(prec5[0], inputs.size(0))

            print('Epoch[{0}]({1}/{2}): '
                  'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                      epoch, batch_idx, num_iterations, top1=top1, top5=top5))

        if (top1.avg.item() > current_model_best_acc):
            current_model_best_acc = top1.avg.item()
            model_state = {
                'state_dict': net.state_dict(),
                'best_prec1': current_model_best_acc,
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'cfg': cfg,
            }

            torch.save(model_state, save_path)

    # Report the tracked best. model_state is undefined on a fresh run that
    # never improved, which used to raise NameError here.
    print("=>Best accuracy {:.3f}".format(current_model_best_acc))
示例#5
0
def _log_run_settings(log):
    """Write the parsed arguments and the key pruning settings to *log*."""
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Norm Pruning Rate: {}".format(args.rate_norm), log)
    print_log("Distance Pruning Rate: {}".format(args.rate_dist), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)
    print_log("Dist type: {}".format(args.dist_type), log)
    print_log("Pre cfg: {}".format(args.use_precfg), log)


def _build_cifar_loaders():
    """Return ``(train_loader, test_loader)`` for the dataset in ``args.dataset``.

    ``'cifar10'`` selects CIFAR-10; any other value selects CIFAR-100. Both use
    the same augmentation and normalization constants.
    """
    if args.dataset == 'cifar10':
        dataset_cls = datasets.CIFAR10
    else:
        dataset_cls = datasets.CIFAR100
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    train_set = dataset_cls(
        args.data_path,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    test_set = dataset_cls(
        args.data_path,
        train=False,
        transform=transforms.Compose([transforms.ToTensor(), normalize]))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    # Evaluation order does not affect accuracy. CIFAR-100 used to shuffle its
    # test set while CIFAR-10 did not; both are now deterministic.
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)
    return train_loader, test_loader


def _run_with_log(log):
    """Run the whole experiment, writing progress to the open log file *log*."""
    _log_run_settings(log)
    train_loader, test_loader = _build_cifar_loaders()

    print_log("=> creating model '{}'".format(args.arch), log)
    model = eval(args.arch)(args.num_class)
    print_log("=> network :\n {}".format(model), log)

    if args.cuda:
        model.cuda()

    if args.use_pretrain:
        pretrain = get_model(args.resume, device='cuda')
        if args.use_state_dict:
            model.load_state_dict(pretrain['state_dict'])
        else:
            # Here the 'state_dict' entry is used directly as the model object.
            model = pretrain['state_dict']

    checkpoint = None
    if args.resume:
        checkpoint = get_model(args.resume, device='cuda')
        args.start_epoch = checkpoint['epoch']
        model = eval(args.arch)(args.num_class, cfg=checkpoint['cfg'])
        model.load_state_dict(checkpoint['state_dict'])
        if args.cuda:
            model = model.cuda()

    # Build the optimizer only once the final model object exists and sits on
    # its device. It used to be created before the resume branch replaced
    # `model`, so it kept updating the discarded network's parameters.
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
            args.resume, checkpoint['epoch'], checkpoint['best_prec1']))

    if args.evaluate:
        time1 = time.time()
        test(test_loader, model, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    m = Mask(model)
    m.init_length()
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("remaining ratio of pruning : Norm is %f" % args.rate_norm)
    print("reducing ratio of pruning : Distance is %f" % args.rate_dist)
    print("total remaining ratio is %f" % (args.rate_norm - args.rate_dist))

    val_acc_1 = test(test_loader, model, log)

    print(" accu before is: %.3f %%" % val_acc_1)

    m.model = model

    m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
    m.do_mask()
    m.do_similar_mask()
    model = m.model
    if args.cuda:
        model = model.cuda()
    val_acc_2 = test(test_loader, model, log)
    print(" accu after is: %s %%" % val_acc_2)

    # Learning-rate milestones (lr is divided by 10). Cast to int: for an odd
    # args.epochs the old float milestones (e.g. 82.5) never equalled an
    # integer epoch, so the decay silently never happened.
    lr_decay_epochs = {int(args.epochs * 0.5), int(args.epochs * 0.75)}
    # NOTE(review): the best accuracy restarts at 0 even when resuming, as
    # before; the checkpoint's best_prec1 is only printed.
    best_prec1 = 0.
    for epoch in range(args.start_epoch, args.epochs):
        if epoch in lr_decay_epochs:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        train(train_loader, model, optimizer, epoch, log)
        prec1 = test(test_loader, model, log)
        # Accuracy of the weights checkpointed below. It is prec1 unless this
        # epoch re-applies the masks, in which case the masked model is re-tested.
        saved_acc = prec1

        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = model
            m.if_zero()
            m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
            m.do_mask()
            m.do_similar_mask()
            m.if_zero()
            model = m.model
            if args.cuda:
                model = model.cuda()
            saved_acc = test(test_loader, model, log)
        # Rank checkpoints by the accuracy of the weights actually saved. Before,
        # is_best used a possibly stale post-mask accuracy while best_prec1
        # tracked the pre-mask one.
        is_best = saved_acc > best_prec1
        best_prec1 = max(saved_acc, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'cfg': model.cfg
            },
            is_best,
            filepath=args.save_path)


def main():
    """Train ``args.arch`` on CIFAR while periodically re-applying filter masks.

    Creates ``args.save_path`` if needed and logs the configuration to
    ``log_seed_<seed>.txt`` there. It then builds the data loaders and the model,
    and optionally loads a pretrained model and/or resumes a checkpoint. With
    ``args.evaluate`` set it only times one test pass. Otherwise it applies
    norm/distance masks via ``Mask`` and trains, re-masking every
    ``args.epoch_prune`` epochs (and on the final epoch) and checkpointing
    each epoch through ``save_checkpoint``.
    """
    os.makedirs(args.save_path, exist_ok=True)
    log_path = os.path.join(args.save_path,
                            'log_seed_{}.txt'.format(args.seed))
    # `with` closes the log on every exit path, including the evaluate-only
    # early return; the file used to be left open.
    with open(log_path, 'w') as log:
        _run_with_log(log)
示例#6
0
    help='how many batches to wait before logging training status')
parser.add_argument('--num_class', type=int, default='100')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader, test_loader = get_loaders(args.dataset, args.data_dir,
                                        args.train_batch_size,
                                        args.eval_batch_size, args.arch)

if args.refine:
    checkpoint = get_model(args.refine, device='cuda')
    model = vgg_16_bn(cfg=checkpoint['cfg'], num_class=args.num_class)
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
    args.start_epoch = checkpoint['epoch']
    try:
        best_prec1 = checkpoint['best_prec1']
    except KeyError:
        best_prec1 = checkpoint['lasted_best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=5e-4)
    # optimizer.load_state_dict(checkpoint['optimizer'])
else:
示例#7
0
def train_baseline():
    """Train a baseline network while accumulating per-layer filter statistics.

    With ``args.resume != "none"`` the whole network object (``'net'``) and best
    accuracy are restored from that checkpoint. Training then runs epochs
    ``start_epoch + 1`` .. ``start_epoch + 129`` and saves to
    ``<baseline_dir>/<args.resume>``. Otherwise a new ``args.arch`` network
    trains for epoch 1 only and saves to
    ``<baseline_dir>/<args.arch>_<args.dataset>.pth``.

    After every backward pass ``get_count_dict`` updates ``count_dict``. A
    snapshot holding the network and those counts is written at batch 100 of
    each epoch. The best (by test top-1) is saved whenever accuracy improves and
    written once more at the end. Uses the module-level ``args``,
    ``baseline_dir``, ``device``, ``trainloader``, ``testloader``,
    ``get_model``, ``get_count_dict`` and ``utils``.
    """
    import contextlib

    if args.resume != "none":
        print('checkpoint %s exists,train from checkpoint' % args.resume)
        save_path = os.path.join(baseline_dir, args.resume)
        model_state = get_model(args.resume, device=device)
        # This checkpoint format stores the whole module object under 'net'.
        net = model_state['net']
        current_model_best_acc = model_state['best_prec1']
        start_epoch = model_state.get('epoch', 0)
        end_epoch = start_epoch + 130
    else:
        save_path = os.path.join(baseline_dir,
                                 args.arch + '_' + args.dataset + '.pth')
        current_model_best_acc = 0
        net = eval(args.arch)(args.num_class)
        start_epoch = 0
        end_epoch = 2
        # Nothing worth saving yet; set once test accuracy first improves.
        model_state = None

    net = net.to(device)
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    cudnn.benchmark = True
    print('\nEpoch: %d' % start_epoch)
    count_dict = {}
    criterion = nn.CrossEntropyLoss()
    # torch.cuda.device() raises ValueError for a non-CUDA device, so only
    # select it when training on a GPU.
    on_cuda = torch.device(device).type == 'cuda'
    for epoch in range(start_epoch + 1, end_epoch):
        net.train()
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            with (torch.cuda.device(device) if on_cuda else contextlib.nullcontext()):
                inputs = inputs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                loss.backward()

                # Runs between backward() and step(). NOTE(review): presumably
                # get_count_dict inspects the fresh gradients. Confirm.
                count_dict = get_count_dict(net, args.compress_rate,
                                            count_dict)
                if batch_idx == 100:
                    # Snapshot after 100 mini-batches; written to save_path and
                    # possibly overwritten by the best-model saves below.
                    model_state1 = {
                        'net': net,
                        'best_prec1': current_model_best_acc,
                        'epoch': epoch,
                        'counter_dict': count_dict
                    }
                    torch.save(model_state1, save_path)
                    print('saved at batch' + str(batch_idx))

                optimizer.step()

        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        net.eval()
        num_iterations = len(testloader)
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                prec1, prec5 = utils.accuracy(outputs, targets, topk=(1, 5))
                top1.update(prec1[0], inputs.size(0))
                top5.update(prec5[0], inputs.size(0))

            print('Epoch[{0}]({1}/{2}): '
                  'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                      epoch, batch_idx, num_iterations, top1=top1, top5=top5))

        if (top1.avg.item() > current_model_best_acc):
            current_model_best_acc = top1.avg.item()
            model_state = {
                'net': net,
                'best_prec1': current_model_best_acc,
                'epoch': epoch,
                'counter_dict': count_dict
            }

            torch.save(model_state, save_path)
    # Report the tracked best. model_state is undefined on a fresh run that
    # never improved, which used to raise NameError here.
    print("=>Best accuracy {:.3f}".format(current_model_best_acc))
    # Re-save the best state so save_path does not end on the batch-100 snapshot.
    if model_state is not None:
        torch.save(model_state, save_path)
示例#8
0
# Model
print('==> Building model..')
print(compress_rate)
net = eval(args.arch)(compress_rate=compress_rate, num_class=args.num_class)
net = net.to(device)

if len(args.gpu) > 1 and torch.cuda.is_available():
    # NOTE(review): args.gpu looks like a comma-separated id string ("0,1"),
    # so (len + 1) // 2 counts the ids. Devices 0..k-1 are used regardless of
    # which ids were listed. Confirm.
    device_id = list(range((len(args.gpu) + 1) // 2))
    net = torch.nn.DataParallel(net, device_ids=device_id)

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # NOTE(review): the checkpoint file name is hard-coded; args.resume only
    # acts as an on/off flag here.
    checkpoint = get_model('vgg_16_bn_cifar10_hr_pruning.pth', device=device)
    from collections import OrderedDict

    # Strip any DataParallel 'module.' prefix. With adjust_ckpt the checkpoint
    # itself is the state dict; otherwise it lives under 'state_dict'.
    new_state_dict = OrderedDict()
    if args.adjust_ckpt:
        for k, v in checkpoint.items():
            new_state_dict[k.replace('module.', '')] = v
    else:
        for k, v in checkpoint['state_dict'].items():
            new_state_dict[k.replace('module.', '')] = v
    # The stripped keys match the bare network. When it is wrapped in
    # DataParallel, the wrapper's keys carry a 'module.' prefix, so load into
    # the wrapped module (loading into the wrapper failed on key mismatch).
    if isinstance(net, torch.nn.DataParallel):
        net.module.load_state_dict(new_state_dict)
    else:
        net.load_state_dict(new_state_dict)

criterion = nn.CrossEntropyLoss()
feature_result = torch.tensor(0.)
total = torch.tensor(0.)
示例#9
0
def get_filter_idx(count_dict, compress_rate):
    """Select, for each layer, the most frequently counted filter indices.

    Args:
        count_dict: mapping of layer name -> Counter-like object (anything with
            ``len`` and ``most_common``) of filter index -> frequency.
        compress_rate: fraction of the counted filters to drop; the top
            ``floor((1 - compress_rate) * len(counter))`` are kept.

    Returns:
        dict mapping each layer name to the list of kept filter indices, most
        frequent first.
    """
    result_dict = {}
    for layer_name, counter in count_dict.items():
        filter_total_number = len(counter)
        # Floor with a small tolerance: e.g. (1 - 0.8) * 10 evaluates to
        # 1.9999999999999996, which plain int() truncated to 1 instead of 2.
        keep = int((1 - compress_rate) * filter_total_number + 1e-9)
        result_dict[layer_name] = [idx for idx, _ in counter.most_common(keep)]
    return result_dict

def get_last_convidx(net):
    """Return the zero-based index of the last ``nn.Conv2d`` in ``net.modules()``.

    This equals the number of Conv2d modules minus one, so it is ``-1`` when
    *net* contains none.
    """
    conv_count = sum(1 for module in net.modules() if isinstance(module, nn.Conv2d))
    return conv_count - 1


# def store_pruned_model(model_name,cop):
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     model = get_model(model_name, device=device)
#     pruned_model_100 = prune_model(model, 0.5)
#     torch.save(pruned_model_100, "D:\\workspace\\prune_paper\\main\\startover\\baseline\\VGG_cifar100_pruned.pth");


if __name__ == '__main__':
    # Load a saved model, prune it with ratio 0.5 and write the result to disk.
    # NOTE(review): the input is a ResNet-34 checkpoint but the output file is
    # named "VGG_cifar100_pruned.pth". Confirm the intended name.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    model = get_model("resnet_34_cifar100.pth", device=device)
    pruned_model_100 = prune_model(model, 0.5)
    torch.save(
        pruned_model_100,
        "D:\\workspace\\prune_paper\\main\\startover\\baseline\\VGG_cifar100_pruned.pth")

示例#10
0
                                              (1 - compress_rate) * 100)
            filter_cpunt = get_filter_index(likely_ndarray)
            try:
                count_dict[layer_name] += filter_cpunt
            except KeyError:
                count_dict[layer_name] = filter_cpunt
        # grad_dist_matrix = get_dist_matrix(net, 4)
        #
        # likely_ndarray = get_likely_point(grad_dist_matrix, (1 - args.compress_rate) * 100)
        # filter_count += get_filter_index(likely_ndarray)
    return count_dict


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_state_dict = get_model('VGG_cifar10.pth', device=device)
    f_idx, count = get_count_statics(model_state_dict['counter_dict'], 20)
    count = list(map(lambda x: int(x / 30), count))
    n_bins = f_idx[:20]  # 总的卷积核的个数按照顺序排
    x = count[:20]  # 每个卷积核被相似的频率数
    # ['bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-bright',
    #  'seaborn-colorblind', 'seaborn-dark-palette', 'seaborn-dark', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted',
    #  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 'seaborn-ticks',
    #  'seaborn-white', 'seaborn-whitegrid', 'seaborn', 'Solarize_Light2', 'tableau-colorblind10', '_classic_test']
    plt.style.use(['bmh'])
    #梯度前20名柱状图,10min
    plt.title('Gradient similarity frequency after 10 mini-batches')
    n_bins = [
        16, 28, 39, 159, 178, 222, 239, 254, 283, 293, 296, 301, 306, 316, 321,
        405, 408, 451, 452, 508
    ]