Example #1
def calculate_flops(current_model):
    if args.expand:
        if args.arch == "resnet56":
            model_ref = models.resnet_expand.resnet56(num_classes=num_classes)
        else:
            raise NotImplementedError()
    else:
        if re.match("vgg.+", args.arch):
            model_ref = models.__dict__[args.arch](num_classes=num_classes)
        else:
            raise NotImplementedError()
    current_flops = count_model_param_flops(current_model.cpu(), 32)
    ref_flops = count_model_param_flops(model_ref.cpu(), 32)
    flops_ratio = current_flops / ref_flops

    print("FLOPs remains {}".format(flops_ratio))
Example #2
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    args.distributed = args.world_size > 1

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    #################################################################################
    if args.model == 'resnet-2x':
        model = models.resnet_2x()
        model_ref = models.resnet50_official()
    elif args.model == 'vgg-5x':
        model = models.vgg_5x()
        model_ref = models.vgg_official()
    else:
        raise NotImplementedError()

    flops_std = count_model_param_flops(model_ref, 224)
    flops_small = count_model_param_flops(model, 224)
    ratio = flops_std / flops_small
    if ratio >= 2:
        args.epochs = 180
        step_size = 60
    else:
        args.epochs = int(90 * ratio)
        step_size = int(args.epochs / 3)
    #################################################################################

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, step_size)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)
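
The block between the comment banners sets the "Scratch-B" training budget: the pruned model gets epochs proportional to how much cheaper it is than the unpruned reference, capped at twice the standard 90-epoch schedule. A quick standalone check of the arithmetic:

# Worked example of the epoch-budget rule above.
for ratio in (1.5, 2.0, 3.0):
    epochs = 180 if ratio >= 2 else int(90 * ratio)
    step_size = 60 if ratio >= 2 else epochs // 3
    print(ratio, epochs, step_size)  # 1.5 135 45 / 2.0 180 60 / 3.0 180 60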
Example #3
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    net_ref = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    net_ref = torch.nn.DataParallel(net_ref, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            net_ref.load_state_dict(checkpoint['state_dict'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    flops_std = count_model_param_flops(net, 32)
    flops_small = count_model_param_flops(net_ref, 32)

    ratio = flops_std / flops_small
    args.epochs = int(400 * ratio)
    print("Total epochs %d" % args.epochs)
    schedule = args.schedule
    args.schedule = [
        1,
        int(schedule[1] * ratio),
        int(schedule[2] * ratio),
        int(schedule[3] * ratio)
    ]
    print(args.schedule)

    recorder = RecorderMeter(args.epochs)
    ###################################################################################################################
    for m, m_ref in zip(net.modules(), net_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            n = mask.sum() / float(m.in_channels)
            m.weight.data.normal_(0, math.sqrt(2. / n))
            m.weight.data.mul_(mask)
    ###################################################################################################################

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    m = Mask(net)

    m.init_length()

    comp_rate = args.rate
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    print(" accu before is: %.3f %%" % val_acc_1)

    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
    print(" accu after is: %s %%" % val_acc_2)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
                time_string(), epoch, args.epochs, need_time, current_learning_rate)
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False), 100 - recorder.max_accuracy(False)), log)

        num_parameters = get_conv_zero_param(net)
        print_log('Zero parameters: {}'.format(num_parameters), log)
        num_parameters = sum([param.nelement() for param in net.parameters()])
        print_log('Parameters: {}'.format(num_parameters), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_1,
                                  val_acc_1)

        save_checkpoint(
            {
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    log.close()
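
The block between the banner comments (repeated in Example #6) copies the sparsity pattern of a pruned reference and re-initializes the surviving weights with He initialization scaled by the nonzero fan-in. A self-contained sketch of the same step on a single layer; the layer shape and mask density here are illustrative:

import math
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
# Hypothetical reference mask: 1 where the pruned model kept a weight.
mask = (torch.rand_like(conv.weight) > 0.5).float()

n = mask.sum() / float(conv.in_channels)      # effective fan-in estimate
conv.weight.data.normal_(0, math.sqrt(2. / n))
conv.weight.data.mul_(mask)                   # re-zero the pruned positions
print((conv.weight == 0).float().mean())      # ~0.5: half the weights stay zero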
Example #4
        mask = bn3_masks[i]
        assert mask[1].shape[0] == m.expand_layer.idx.shape[0]
        m.expand_layer.idx = np.argwhere(
            mask[1].clone().cpu().numpy()).squeeze()

torch.save(
    {
        'cfg': cfg,
        'state_dict': newmodel.state_dict(),
        "bn3_masks": bn3_masks
    }, os.path.join(args.save, '{}.pth.tar'.format(output_name)))

# print(newmodel)
model = newmodel

flops = count_model_param_flops(model.cuda(), 224)
print("FLOPs after pruning: {}".format(flops))

summary = pruning_summary_resnet50(model, False)
print(summary)

# evaluate model
test(model, args)

with open(savepath, "a") as fp:
    fp.write("FLOPs after pruning: {} \n".format(flops))
    fp.write("\n\n\n")
    fp.write("************MODEL SUMMARY************")
    fp.write(summary)
    fp.write("*************************************")
Example #5
            continue
        mask = bn2_masks[i]
        assert mask[1].shape[0] == m.expand_layer.idx.shape[0]
        m.expand_layer.idx = np.argwhere(
            mask[1].clone().cpu().numpy()).squeeze().reshape(-1)

torch.save(
    {
        'cfg': cfg,
        'state_dict': newmodel.state_dict(),
        "bn3_masks": bn2_masks
    }, os.path.join(args.save, '{}.pth.tar'.format(output_name)))

model.enable_aux_fc = False
newmodel.enable_aux_fc = False
flops_ref = count_model_param_flops(model.cpu(), 32)
model = newmodel
flops = count_model_param_flops(model.cpu(), 32)

summary = pruning_summary_resnet56(model, num_classes=num_classes)
print(summary)

pruned_acc = test(model, test_loader)
print("=> Pruned completed. Test acc: {}".format(load_acc))

with open(savepath, "a") as fp:
    fp.write("FLOPs before pruning: {} \n".format(flops_ref))
    fp.write("FLOPs after pruning: {} \n".format(flops))
    fp.write("\n\n\n")
    fp.write("************MODEL SUMMARY************")
    fp.write(summary)
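
Examples #4 and #5 differ in one detail: converting a boolean mask to channel indices with np.argwhere(...).squeeze() yields a 0-d array when exactly one channel survives, so Example #5 appends .reshape(-1) as a guard. A minimal demonstration:

import numpy as np

mask = np.array([0., 1., 0., 0.])                # only channel 1 survives
idx = np.argwhere(mask).squeeze()
print(idx.shape)                                 # () -- 0-d, not indexable
idx = np.argwhere(mask).squeeze().reshape(-1)
print(idx.shape, idx)                            # (1,) [1]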
Example #6
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        model_ref = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
        model_ref = models.__dict__[args.arch]()

    ######################################################################################################
    flops_std = count_model_param_flops(model, 224)
    flops_small = count_model_param_flops(model_ref, 224)
    args.epochs = int(90 * flops_std / flops_small)
    step_size = int(args.epochs / 3)
    print("Scratch-B training total epochs %d" % args.epochs)
    ######################################################################################################

    if args.gpu is not None:
        model = model.cuda(args.gpu)
        model_ref = model_ref.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model_ref.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
        model_ref = torch.nn.parallel.DistributedDataParallel(model_ref)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model_ref.features = torch.nn.DataParallel(model_ref.features)
            model.cuda()
            model_ref.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
            model_ref = torch.nn.DataParallel(model_ref).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model_ref.load_state_dict(checkpoint['state_dict'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set some weights to zero, according to model_ref ---------------------------------
    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            n = mask.sum() / float(m.in_channels)
            m.weight.data.normal_(0, math.sqrt(2. / n))
            m.weight.data.mul_(mask)
    # ----------------------------------------------------------------------------------

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, step_size)

        #####################################################################################################
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum(
            [param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))
        #####################################################################################################

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.save)
    return
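
get_conv_zero_param is used above but not defined on this page. A plausible implementation consistent with its usage, counting convolution weights the mask has driven to zero (an assumption about the real helper):

import torch.nn as nn

def get_conv_zero_param(model):
    # Count exactly-zero weights across all Conv2d layers.
    total = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            total += (m.weight.data == 0).sum().item()
    return total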
Example #7
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre)
        mask = mask.float().cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
flops = compute_flops.count_model_param_flops(model=model, input_res=224, multiply_adds=False)
print('FLOPs after pruning: {}'.format(flops))
torch.save({'cfg': cfg, 'state_dict': model.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

pruned_ratio = pruned/total

print('Pre-processing Successful!')

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
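
A standalone usage check for the helper above, on fake logits:

import torch

output = torch.randn(8, 10)              # logits for 8 samples, 10 classes
target = torch.randint(0, 10, (8,))      # fake labels
acc1, acc5 = accuracy(output, target, topk=(1, 5))
print(acc1.item(), acc5.item())          # precision@1 and precision@5 in percent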
Example #8
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if args.resume:
        if not os.path.isdir(args.resume):
            os.makedirs(args.resume)
    log = open(os.path.join(args.save_path, '{}.txt'.format(args.description)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("use cuda: {}".format(args.use_cuda), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("description: {}".format(args.description), log)

    # Init data loader
    if args.dataset == 'cifar10':
        train_loader = dataset.cifar10DataLoader(True, args.batch_size, True, args.workers)
        test_loader = dataset.cifar10DataLoader(False, args.batch_size, False, args.workers)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_loader = dataset.cifar100DataLoader(True, args.batch_size, True, args.workers)
        test_loader = dataset.cifar100DataLoader(False, args.batch_size, False, args.workers)
        num_classes = 100
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    # Init model
    if args.arch == 'cifarvgg16':
        net = models.vgg16_cifar(True, num_classes)
    elif args.arch == 'resnet32':
        net = models.resnet32(num_classes)
    elif args.arch == 'resnet56':
        net = models.resnet56(num_classes)
    elif args.arch == 'resnet110':
        net = models.resnet110(num_classes)
    else:
        assert False, 'Unsupported architecture: {}'.format(args.arch)

    print_log("=> network:\n {}".format(net), log)
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume + 'checkpoint.pth.tar'):
            print_log("=> loading checkpoint '{}'".format(args.resume + 'checkpoint.pth.tar'), log)
            checkpoint = torch.load(args.resume + 'checkpoint.pth.tar')
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)

            if args.evaluate:
                time1 = time.time()
                validate(test_loader, net, criterion, args.use_cuda, log)
                time2 = time.time()
                print('validate function took %0.3f ms' % ((time2 - time1) * 1000.0))
                return
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> not use any checkpoint for {} model".format(args.description), log)

    if args.original_train:
        original_train.args.arch = args.arch
        original_train.args.dataset = args.dataset
        original_train.main()
        return

    comp_rate = args.rate
    m = mask.Mask(net, args.use_cuda)
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, args.use_cuda, log)
    print("accuracy before pruning: %.3f %%" % val_acc_1)

    m.model = net
    print('before pruning')
    m.init_mask(comp_rate, args.last_index)
    m.do_mask()
    print('after pruning')
    m.print_weights_zero()
    net = m.model  # update net

    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)
    print("accuracy after pruning: %.3f %%" % val_acc_2)

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, optimizer, epoch,
                                                     args.gammas, args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
                time_string(), epoch, args.epochs, need_time, current_learning_rate)
            + ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)), log)
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, args.use_cuda, log)
        validate(test_loader, net, criterion, args.use_cuda, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            print('before pruning')
            m.print_weights_zero()
            m.init_mask(comp_rate, args.last_index)
            m.do_mask()
            print('after pruning')
            m.print_weights_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)
        if args.resume:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.resume, 'checkpoint.pth.tar')
        print('save ckpt done')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    torch.save(net, args.model_save)
    # torch.save(net, args.save_path)
    flops.print_model_param_nums(net)
    flops.count_model_param_flops(net, 32, False)
    log.close()
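
Example #8 drives pruning through a mask.Mask helper (init_mask, do_mask, print_weights_zero). Its internals are not shown on this page; the sketch below is one plausible magnitude-based implementation matching that interface. The keep-rate reading of comp_rate and the L1 filter criterion are assumptions, not the original class.

import torch
import torch.nn as nn

class Mask(object):
    """Soft filter pruning sketch: zero the smallest-norm conv filters."""

    def __init__(self, model, use_cuda=True):
        self.model = model
        self.use_cuda = use_cuda   # stored for interface parity
        self.masks = {}

    def init_mask(self, comp_rate, last_index=None):
        # comp_rate is read here as the fraction of filters to keep.
        for name, m in self.model.named_modules():
            if isinstance(m, nn.Conv2d):
                w = m.weight.data
                norms = w.view(w.size(0), -1).abs().sum(dim=1)  # L1 per filter
                n_prune = int(w.size(0) * (1 - comp_rate))
                mask = torch.ones_like(w)
                mask[norms.argsort()[:n_prune]] = 0.
                self.masks[name] = mask

    def do_mask(self):
        for name, m in self.model.named_modules():
            if name in self.masks:
                m.weight.data.mul_(self.masks[name].to(m.weight.device))

    def print_weights_zero(self):
        zeros = sum((m.weight.data == 0).sum().item()
                    for m in self.model.modules() if isinstance(m, nn.Conv2d))
        print('zeroed conv weights: {}'.format(zeros))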
Example #9
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1, ))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1, ))
        w1 = m0.weight.data[:, idx0.tolist()].clone()
        if layer_id_in_cfg != len(cfg_mask):
            w1 = w1[idx1.tolist(), :].clone()
            bias1 = m0.bias.data[idx1.tolist()].clone()
        else:
            bias1 = m0.bias.data.clone()
        assert m1.weight.data.shape == w1.shape
        assert m1.bias.data.shape == bias1.shape
        m1.weight.data = w1.clone()
        m1.bias.data = bias1.clone()

torch.save({
    'cfg': cfg,
    'state_dict': newmodel.state_dict()
}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
pruned_acc = test(newmodel)
print("Accuracy after pruning: {}".format(pruned_acc))

# calculate FLOPs
base_flops = count_model_param_flops(model, 32)
pruned_flops = count_model_param_flops(newmodel, 32)
flops_ratio = pruned_flops / base_flops
print("Pruning FLOPs: {}".format(flops_ratio))
Example #10
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
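
A quick usage sketch of the meter, mirroring how the training loops above accumulate per-batch statistics:

losses = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    losses.update(batch_loss, n=batch_size)
print(losses.val, losses.avg)   # 0.5 and the size-weighted mean 0.74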


print("Starting evaluating...")
# acc = test()
print("Skip evaluation. Aborted.")

print("Computing FLOPs...")
print("cfg: ", cfg)

# calculate FLOPs
flops = count_model_param_flops(new_model.cuda(), 224)
flops_unpruned = count_model_param_flops(model.cuda(), 224)
print("FLOPs after pruning: {}".format(flops))
print("FLOPs Unpruned: {}".format(flops_unpruned))