Ejemplo n.º 1
0
def finetune_network(params, checkpointer, train_loader, test_loader, model, criterion, optimiser):
    """Train ``model`` for epochs ``params.start_epoch .. params.epochs - 1``, evaluating after each.

    Per epoch: set ``params.curr_epoch``, call ``update_lr``, run ``train`` on
    every batch of ``train_loader`` and store the ``.avg`` of the per-batch
    loss / top-1 / top-5 meters in ``params.train_loss`` / ``params.train_top1``
    / ``params.train_top5``.

    Evaluation depends on ``params.finetune``:

    * ``== True``: on every 15th epoch (never epoch 0), if weight or filter
      pruning is enabled, the model is re-pruned with ``pruning.prune_model``
      and ``params.pruning_perc`` is then increased by 10. The model is tested
      on ``test_loader['subset']`` (checkpointed with ``save_cp=True``, logged
      as config ``'11'``) and on ``test_loader['orig']`` (``save_cp=False``,
      config ``'01'``). Test top-1 is also written to ``params.tbx`` when set.
    * otherwise: tested on ``test_loader`` itself, checkpointed with the
      checkpointer's default arguments, and logged as config ``'00'``.

    Each evaluation stores its results in ``params.test_loss``,
    ``params.test_top1`` and ``params.test_top5``, and prints one
    tab-separated row under the header printed at start-up.
    """
    row_template = '{},\t{},\t{},\t{},\t{},\t{},\t{},\t{},\t{}'

    def report(config, epoch):
        # One log row: config tag, epoch, LR, then train and test metrics.
        print(row_template.format(config, epoch, params.lr,
                                  params.train_loss, params.train_top1, params.train_top5,
                                  params.test_loss, params.test_top1, params.test_top5))

    def log_top1(tag_suffix):
        # Mirror the latest test top-1 to the TensorBoard writer, if one is attached.
        if params.tbx is not None:
            params.tbx.add_scalar('__'.join(params.sub_classes) + tag_suffix,
                                  params.test_top1, params.curr_epoch)

    print('Config,\tEpoch,\tLR,\tTrain_Loss,\tTrain_Top1,\tTrain_Top5,\tTest_Loss,\tTest_Top1,\tTest_Top5')

    for epoch in range(params.start_epoch, params.epochs):
        params.curr_epoch = epoch
        update_lr(params, optimiser)

        loss_meter = utils.AverageMeter()
        top1_meter = utils.AverageMeter()
        top5_meter = utils.AverageMeter()

        for inputs, targets in train_loader:
            if params.use_cuda:
                # move the batch to the GPU
                inputs, targets = inputs.cuda(), targets.cuda()
            inputs = torch.autograd.Variable(inputs)
            targets = torch.autograd.Variable(targets)

            batch_loss, batch_top1, batch_top5 = train(model, criterion, optimiser, inputs, targets)
            loss_meter.update(batch_loss)
            top1_meter.update(batch_top1)
            top5_meter.update(batch_top5)

        params.train_loss = loss_meter.avg
        params.train_top1 = top1_meter.avg
        params.train_top5 = top5_meter.avg

        if not (params.finetune == True):
            # Plain training run: a single test set and default checkpointing.
            params.test_loss, params.test_top1, params.test_top5 = inference.test_network(
                params, test_loader, model, criterion, optimiser)
            checkpointer.save_checkpoint(model.state_dict(), optimiser.state_dict(), params.get_state())
            report('00', epoch)
            continue

        pruning_enabled = params.prune_weights == True or params.prune_filters == True
        if pruning_enabled and epoch % 15 == 0 and epoch != 0:
            print('Pruning network')
            model = pruning.prune_model(params, model)
            params.pruning_perc += 10
            print('Pruned Percentage = {}'.format(pruning_utils.prune_rate(params, model)))

        # new model evaluated on the class subset
        params.test_loss, params.test_top1, params.test_top5 = inference.test_network(
            params, test_loader['subset'], model, criterion, optimiser)
        checkpointer.save_checkpoint(model.state_dict(), optimiser.state_dict(), params.get_state(),
                                     save_cp=True, config='11')
        report('11', epoch)
        log_top1('/top1_subset_on_new_model')

        # new model evaluated on the full original test set
        params.test_loss, params.test_top1, params.test_top5 = inference.test_network(
            params, test_loader['orig'], model, criterion, optimiser)
        checkpointer.save_checkpoint(model.state_dict(), optimiser.state_dict(), params.get_state(),
                                     save_cp=False, config='01')
        report('01', epoch)
        log_top1('/top1_all_on_new_model')
Ejemplo n.º 2
0
def filter_prune(params, model):
    """Prune filters one at a time until ``params.pruning_perc`` is reached.

    One-shot (not iterative) pruning. ``params.pruned_layers`` is reset to an
    empty list. Then, while the measured pruning percentage is below
    ``params.pruning_perc``, the loop repeats three steps:

    1. ask ``prune_one_filter`` for an updated mask list;
    2. apply it via ``model.module.set_masks``;
    3. re-measure the percentage with ``prune_rate(..., verbose=False)``.

    Returns the final mask list (``[]`` if the target is already <= 0).

    NOTE(review): there is no guard against ``prune_rate`` never reaching the
    target (e.g. a target above 100); the loop would then not terminate.
    """
    params.pruned_layers = []
    mask_list = []
    achieved = 0.

    while achieved < params.pruning_perc:
        mask_list = prune_one_filter(params, model, mask_list)
        model.module.set_masks(mask_list)
        achieved = prune_rate(params, model, verbose=False)

    return mask_list
Ejemplo n.º 3
0
def filter_prune(model, pruning_perc):
    """Prune filters one at a time until ``pruning_perc`` is reached.

    One-shot (not iterative) pruning. Each step runs ``prune_one_filter`` to
    extend the mask list and applies it with ``model.set_masks``. It then
    re-measures ``prune_rate(model, verbose=False)`` and prints the percentage
    with two decimals.

    Returns the final mask list (``[]`` if ``pruning_perc`` <= 0).
    """
    pruned_so_far = 0.
    mask_list = []

    while pruned_so_far < pruning_perc:
        mask_list = prune_one_filter(model, mask_list)
        model.set_masks(mask_list)
        pruned_so_far = prune_rate(model, verbose=False)
        print('{:.2f} pruned'.format(pruned_so_far))

    return mask_list
                            momentum=param['momentum'],
                            weight_decay=param['weight_decay'])

# Copy net's weights into new_net, then give every Conv2d layer (in
# new_net.modules() order) an all-zero mask of the same shape as its weight,
# and apply the masks.
# NOTE(review): new_net.mask is indexed, not appended to, so it presumably
# already holds one entry per Conv2d layer — confirm.
new_net.load_state_dict(net.state_dict())
count = 0
for m in new_net.modules():
    if isinstance(m, nn.Conv2d):
        new_net.mask[count] = torch.zeros_like(m.weight.data)
        count += 1

new_net.set_masks(new_net.mask)

# Debug dump: print every parameter tensor of the original and the masked net.
for parameter in net.parameters():
    print(parameter.data)

for parameter in new_net.parameters():
    print(parameter.data)

# prune_rate with its default verbosity for the masked copy and the original.
prune_rate(new_net)
prune_rate(net)
# # Retraining
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'],
#                                 weight_decay=param['weight_decay'])
#
# train(net, criterion, optimizer, param, loader_train, loader_test)
#
prune_rate(net)

# NOTE(review): this saves ``net``, not the zero-masked ``new_net`` built above,
# and nothing visible in this script modifies ``net`` (retraining is commented
# out) — confirm the intended model for 'conv_pruned.pkl'.
torch.save(net.state_dict(), 'models/conv_pruned.pkl')
def main():
    """Train (and optionally prune) AlexNet on ImageNet LMDB data.

    All configuration comes from the module-level ``parser``; the parsed
    namespace is stored in the global ``args``.

    Steps:

    1. Build AlexNet (pretrained or not). ``args.arch`` is only echoed and
       selects the ``DataParallel`` wrapping.
    2. Optionally resume from ``args.resume``. State-dict entries whose key
       contains ``'mask'`` are re-applied via ``model.set_masks``. The rest of
       the weights, the optimizer state, ``args.start_epoch`` and the global
       ``best_prec1`` are restored.
    3. If ``args.prune > 0`` and nothing was resumed, apply ``weight_prune``
       masks and validate once.
    4. If ``args.evaluate`` is set, validate and return. Otherwise train for
       epochs ``[args.start_epoch, args.epochs)``, validating and calling
       ``save_checkpoint`` after each epoch. Finally, save the weights to
       ``<args.logfolder>/alexnet_pruned.pkl``.

    Raises:
        NotImplementedError: if ``args.world_size > 1``. The LMDB ``Loader``
            accepts no sampler, and no dataset object exists for a
            ``DistributedSampler`` to wrap. Previously this path crashed with
            ``NameError: train_dataset`` after the process group had been
            initialised.
    """
    global args, best_prec1
    args = parser.parse_args()
    pruning = False
    chkpoint = False

    args.distributed = args.world_size > 1
    if args.distributed:
        # Fail before a process group is created or the model is moved to GPU.
        raise NotImplementedError(
            'distributed training (world_size={}) is not supported: the LMDB '
            'Loader takes no sampler'.format(args.world_size))

    # create model (always AlexNet; args.arch is only reported here)
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = alexnet(pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = alexnet(pretrained=False)

    if args.arch.startswith(('alexnet', 'vgg')):
        # Only the convolutional trunk is wrapped; ``model`` itself stays the
        # AlexNet instance on which set_masks is called below.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        # NOTE(review): here ``model`` becomes a DataParallel wrapper, which
        # does not forward ``set_masks``; the resume/prune paths below would
        # need ``model.module`` — confirm this branch is ever used.
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Masks are saved in the state dict next to the weights: load the
            # weights with load_state_dict and re-apply the masks (in saved
            # order) through set_masks.
            saved_state = checkpoint['state_dict']
            params = {k: v for k, v in saved_state.items() if 'mask' not in k}
            mask_params = {k: v for k, v in saved_state.items() if 'mask' in k}
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(params)
            model.set_masks(list(mask_params.values()))
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            prune_rate(model)
            chkpoint = True
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.prune > 0 and not chkpoint:
        # prune
        print("=> pruning...")
        masks = weight_prune(model, args.prune)
        model.set_masks(masks)
        pruning = True

    cudnn.benchmark = True

    # Data loading: ImageNet LMDB directories under args.data, read by the
    # project's Loader (which does its own batching; no sampler is used).
    traindir = os.path.join(args.data, 'ilsvrc12_train_lmdb_224_pytorch')
    valdir = os.path.join(args.data, 'ilsvrc12_val_lmdb_224_pytorch')

    train_loader = Loader('train',
                          traindir,
                          batch_size=args.batch_size,
                          num_workers=args.workers,
                          cuda=True)
    val_loader = Loader('val',
                        valdir,
                        batch_size=args.batch_size,
                        num_workers=args.workers,
                        cuda=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    if pruning:
        # Prune weights validation (pruning is only ever set when no
        # checkpoint was resumed).
        print("--- {}% parameters pruned ---".format(args.prune))
        validate(val_loader, model, criterion)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'iter': 0,
            },
            is_best,
            path=args.logfolder)

    print("--- After retraining ---")
    prune_rate(model)
    torch.save(model.state_dict(),
               os.path.join(args.logfolder, 'alexnet_pruned.pkl'))
def gen_mask(model,
             loss_fn,
             optimizer,
             param,
             loader_train,
             loader_test,
             ratio,
             k=3,
             loader_val=None):
    """Train ``model`` while progressively pruning convolution filters.

    At the end of every ``k``-th epoch, while unused entries remain in
    ``ratio``, three things happen:

    * the next ratio is passed to ``model.com_mask2`` together with the
      previous ratio (0 for the first);
    * the resulting ``model.mask`` is applied;
    * ``model.zero_accmgrad()`` is called.

    In the other epochs the current ``model.mask`` is simply re-applied.

    After each epoch the learning rate is reset for the next one. It stays at
    ``param['learning_rate']`` until all ``k * len(ratio)`` pruning epochs are
    over, then halves every 30 epochs.

    Args:
        model: network exposing ``update_grad``, ``com_mask2``, ``set_masks``,
            ``zero_accmgrad`` and a ``mask`` attribute.
        loss_fn: callable ``loss_fn(scores, targets)`` returning the loss.
        optimizer: torch optimizer whose ``param_groups`` lr is overwritten
            after every epoch.
        param: dict with at least ``'num_epochs'`` and ``'learning_rate'``.
        loader_train: iterable of ``(x, y)`` training batches.
        loader_test: evaluation data passed to ``test``.
        ratio: sequence of pruning ratios, consumed in order.
        k: prune every ``k`` epochs.
        loader_val: unused; kept for call compatibility.
    """
    test(model, loader_test)

    model.train()
    ratio_ind = 0
    num_epochs = param['num_epochs']
    for epoch in range(num_epochs):
        model.train()

        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        for t, (x, y) in enumerate(loader_train):
            x_var, y_var = to_var(x), to_var(y.long())

            scores = model(x_var)
            loss = loss_fn(scores, y_var)

            if (t + 1) % 100 == 0:
                print('t = %d, loss = %.8f' % (t + 1, loss.item()))

            optimizer.zero_grad()
            loss.backward()

            # Model-specific gradient hook run before the optimizer step
            # (presumably masks/accumulates gradients used by com_mask2 —
            # confirm against the model class).
            model.update_grad()

            optimizer.step()

        if (epoch + 1) % k == 0 and ratio_ind < len(ratio):
            print(
                ' pruning some filters which are in convolution layes , pruning ratio:%.3f'
                % ratio[ratio_ind])
            prev_ratio = ratio[ratio_ind - 1] if ratio_ind > 0 else 0
            model.com_mask2(ratio[ratio_ind], prev_ratio)
            model.set_masks(model.mask)

            model.zero_accmgrad()
            ratio_ind += 1
        else:
            model.set_masks(model.mask)
        prune_rate(model)

        print('modify learning rate')
        # Hold the base rate until all k * len(ratio) pruning epochs are done,
        # then halve it every 30 epochs. During the pruning phase the offset is
        # negative and ``//`` floors toward -inf, so without max(0, ...) the
        # exponent was negative and the rate silently became 2x, 4x, ... base.
        decay_steps = max(0, (epoch - k * len(ratio)) // 30)
        lr = param['learning_rate'] * (0.5 ** decay_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        print('epoch', epoch)
        test(model, loader_test)
import time
# Wall-clock duration (seconds, from time.time) of the gen_mask run below.
start = time.time()

# Pruning ratios consumed one at a time by gen_mask; [0.0] means no filters are
# pruned. Earlier experiment settings are kept below for reference.
# ratio_list =[10.0, 20.0, 30.0, 35.5]
# ratio_list =[5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0]
# ratio_list =[10.0, 20.0,30.0, 40.0,  50.0, 60.0]
ratio_list = [0.0]

# ratio_list = [80.0, 90.0]
# The trailing 5 is gen_mask's ``k``: the next ratio is applied every 5 epochs.
gen_mask(net, criterion, optimizer, param, loader_train, loader_test,
         ratio_list, 5)

# print('............')
#
#
# # Retraining
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'],
#                                 weight_decay=param['weight_decay'])
#
# train(net, criterion, optimizer, param, loader_train, loader_test)
#

end = time.time()

print(end - start)

prune_rate(net)

# NOTE(review): the filename says 60% but ratio_list above is [0.0] — confirm
# the output name matches the ratios actually used.
torch.save(net.state_dict(), 'models/conv_pruned_60%.pkl')
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
            'masks': masks,
            'masks_amul': masks_amul,
            'masks_act': masks_act,
        },
        is_best,
        filename='vanilla/' + args.type + '_' + crate_str + '_' +
        str(args.epochs) + '_codesign_checkpoint.pth.tar')
    #error source?
    store_txt = 'vanilla/' + args.type + '_' + crate_str + '_' + str(
        args.epochs) + '_codesign.txt'
    with open(store_txt, 'a') as f:
        f.write('{:.4f} {:.4f}'.format(prec1 * 100, prec5 * 100) + '\n')
    normalized_params(model_raw, masks)
    prune_rate(model_raw)
'''
# print sf
acc1, acc5 = misc.eval_model(model_raw, val_ds, ngpu=args.ngpu, is_imagenet=is_imagenet)
res_str = "type={}, quant_method={}, param_bits={}, bn_bits={}, fwd_bits={}, overflow_rate={}, acc1={:.4f}, acc5={:.4f}".format(
    args.type, args.quant_method, args.param_bits, args.bn_bits, args.fwd_bits, args.overflow_rate, acc1, acc5)
print(res_str)
'''

if args.resume_step_one:
    args.resume = args.type + '_' + crate_str + '_' + str(
        10) + '_codesign_checkpoint.pth.tar'
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
Ejemplo n.º 9
0
                                          batch_size=param['test_batch_size'],
                                          shuffle=True)

# Load pretrained LeNet weights (deserialised onto the CPU first), then move
# the model to the GPU when one is available.
model = LeNet()
model.load_state_dict(
    torch.load('models/lenet_pretrained.pkl', map_location='cpu'))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print("--- Accuracy of Pretrained Model ---")
test(model, loader_test)

# pruning: compute masks with lenet_prune() and apply them to the model.
# NOTE(review): lenet_prune() is given no model argument — presumably it reads
# a module-level model; confirm.
masks = lenet_prune()
model.set_masks(masks)
print("--- Accuracy After Pruning ---")
test(model, loader_test)

# Retraining the pruned model with RMSprop and cross-entropy loss
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(),
                                lr=param['learning_rate'],
                                weight_decay=param['weight_decay'])
train(model, criterion, optimizer, param, loader_train)

print("--- Accuracy After Retraining ---")
test(model, loader_test)
prune_rate(model)

# Save only the model's state_dict (not the whole module object); nothing is
# loaded back here.
torch.save(model.state_dict(), 'models/lenet_pruned.pkl')