Example #1
def test():
    # Build an AMC-pruned MobileNet and report its parameter and FLOP counts.
    from compute_flops import print_model_param_nums, print_model_param_flops
    net = MobileNet(amc=True)
    # Optional sanity check of the forward pass:
    #x = torch.randn(1, 3, 224, 224)
    #y = net(x)
    #print(y.size())
    print_model_param_nums(net)
    print_model_param_flops(net)
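
# A minimal sketch of what a parameter counter such as
# print_model_param_nums could do internally (hypothetical helper; the
# real compute_flops implementation may differ):
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    # Sum the element counts of every parameter tensor in the model.
    return sum(p.numel() for p in model.parameters())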
Example #2
def main():
    global args, best_prec1, device
    args = parser.parse_args()

    batch_size = args.batch_size * max(1, args.num_gpus)
    args.lr = args.lr * (batch_size / 256.)
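    # Linear LR scaling: the base learning rate grows proportionally with
    # the global batch size, with 256 as the reference batch (the "linear
    # scaling rule" of Goyal et al., Accurate, Large Minibatch SGD).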
    print(batch_size, args.lr, args.num_gpus)

    num_classes = 1000
    num_training_samples = 1281167
    args.num_batches_per_epoch = num_training_samples // batch_size

    assert os.path.isfile(args.load) and args.load.endswith(".pth.tar")
    args.save = os.path.dirname(args.load)
    training_mode = 'retrain' if args.retrain else 'finetune'
    args.save = os.path.join(args.save, training_mode)

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    args.model_save_path = os.path.join(
        args.save, "epochs_{}_{}".format(args.epochs,
                                         os.path.basename(args.load)))
    args.distributed = args.world_size > 1

    ##########################################################
    ## create file handler which logs even debug messages
    #import logging
    #log = logging.getLogger()
    #log.setLevel(logging.INFO)

    #ch = logging.StreamHandler()
    #fh = logging.FileHandler(args.logging_file_path)

    #formatter = logging.Formatter('%(asctime)s - %(message)s')
    #ch.setFormatter(formatter)
    #fh.setFormatter(formatter)
    #log.addHandler(fh)
    #log.addHandler(ch)
    ##########################################################

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Use CUDA
    args.use_cuda = torch.cuda.is_available() and not args.no_cuda

    # Random seed
    random.seed(0)
    torch.manual_seed(0)
    if args.use_cuda:
        torch.cuda.manual_seed_all(0)
        device = 'cuda'
        cudnn.benchmark = True
    else:
        device = 'cpu'
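    # Note: cudnn.benchmark trades bit-exact reproducibility for speed; for
    # fully deterministic runs one would instead set
    # torch.backends.cudnn.deterministic = True.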

    if args.evaluate == 1:
        device = 'cuda:0'

    assert os.path.isfile(args.load)
    print("=> loading checkpoint '{}'".format(args.load))
    checkpoint = torch.load(args.load)

    model = mobilenetv2(cfg=checkpoint['cfg'])
    cfg = model.cfg

    total_params = print_model_param_nums(model.cpu())
    total_flops = print_model_param_flops(model.cpu(),
                                          224,
                                          multiply_adds=False)
    print(total_params, total_flops)

    if not args.distributed:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(model)
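    # DataParallel replicates the module across GPUs within a single
    # process, while DistributedDataParallel runs one process per GPU and
    # overlaps gradient communication with the backward pass, which is why
    # it is preferred at larger world sizes.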

    ##### finetune #####
    if not args.retrain:
        model.load_state_dict(checkpoint['state_dict'])

    # define loss function (criterion) and optimizer
    if args.label_smoothing:
        criterion = CrossEntropyLabelSmooth(num_classes).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    #### all parameters ####
    no_wd_params, wd_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if ".bn" in name or '.bias' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)
    no_wd_params = nn.ParameterList(no_wd_params)
    wd_params = nn.ParameterList(wd_params)
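    # Skipping weight decay on BatchNorm parameters and biases is a common
    # recipe for compact networks; decaying the BN scale factors would
    # shrink them toward zero and interact badly with BN-based pruning.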

    optimizer = torch.optim.SGD([
        {
            'params': no_wd_params,
            'weight_decay': 0.
        },
        {
            'params': wd_params,
            'weight_decay': args.weight_decay
        },
    ],
                                args.lr,
                                momentum=args.momentum)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_save_path):
            print("=> loading checkpoint '{}'".format(args.model_save_path))
            checkpoint = torch.load(args.model_save_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_save_path, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.model_save_path))

    # Data loading code
    train_loader, val_loader = get_data_loader(args.data,
                                               train_batch_size=batch_size,
                                               test_batch_size=32,
                                               workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # train_sampler is assumed to be the DistributedSampler returned
            # by the data pipeline; it is not defined in this snippet.
            train_sampler.set_epoch(epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'cfg': cfg,
                #'m': args.m,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            args.model_save_path)

        print('  + Number of params: %.3fM' % (total_params / 1e6))
        print('  + Number of FLOPs: %.3fG' % (total_flops / 1e9))
Example #3
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = fix_robustness_ckpt(torch.load(args.model))
        # args.start_epoch = checkpoint['epoch']
        # best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint, strict=False)
        # print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
        #       .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
        exit()

input_size = 224 if args.dataset == 'imagenet' else 32
print('original model param: ', print_model_param_nums(model))
print('original model flops: ', print_model_param_flops(model, input_size, True))

if args.cuda:
    model.cuda()

total = 0

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
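
# A hedged continuation, mirroring the allocation above: populate `bn`
# with the absolute BN scale factors, the standard gathering step of
# BN-based (Network Slimming style) channel pruning.
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size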
Example #4
        output = model(data)
        test_loss += F.cross_entropy(
            output, target, reduction='sum').item()  # sum up batch loss
        pred = output.data.max(
            1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().numpy().sum()

    test_loss /= len(test_loader.dataset)
    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
    #    test_loss, correct, len(test_loader.dataset),
    #    100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


acc = test(model)

total_params = print_model_param_nums(model.cpu())
total_flops = print_model_param_flops(model.cpu(), 32)

results = {
    'load': args.load,
    'dataset': args.dataset,
    'model_name': args.model_name,
    'arch': 'mobilenetv1',
    'acc': acc,
    'cfg': model.cfg,
    'total_params': total_params,
    'total_flops': total_flops,
}
print(results)
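
# Optionally persist the summary; a minimal sketch assuming a writable
# working directory (the results.json filename is hypothetical):
import json
with open('results.json', 'w') as f:
    json.dump({k: str(v) for k, v in results.items()}, f, indent=2)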
def main():
    global best_prec1, log

    batch_size = args.batch_size * max(1, args.num_gpus)
    args.lr = args.lr * (batch_size / 256.)
    print(batch_size, args.lr, args.num_gpus)

    num_classes = 1000
    num_training_samples = 1281167
    args.num_batches_per_epoch = num_training_samples // batch_size

    assert args.exp_name
    args.save = os.path.join(args.save, args.exp_name)
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    hyper_str = "run_{}_lr_{}_decay_{}_b_{}_gpu_{}".format(args.epochs, args.lr, \
                                args.lr_mode, batch_size, args.num_gpus)

    ## bn-based pruning base model ##
    if args.sr:
        hyper_str = "{}_sr_grow_{}_s_{}".format(hyper_str, args.m, args.s)
    ## using amc configuration ##
    elif args.amc:
        hyper_str = "{}_amc".format(hyper_str)
    elif args.sp:
        hyper_str = "{}_sp_base_{}".format(hyper_str, args.sp_cfg)
    else:
        hyper_str = "{}_grow_{}".format(hyper_str, args.m)

    args.model_save_path = \
            os.path.join(args.save, 'mbv1_{}.pth.tar'.format(hyper_str))

    #args.logging_file_path = \
    #        os.path.join(args.save, 'mbv1_{}.log'.format(hyper_str))
    #print(args.model_save_path, args.logging_file_path)

    ##########################################################
    ## create file handler which logs even debug messages
    #import logging
    #log = logging.getLogger()
    #log.setLevel(logging.INFO)

    #ch = logging.StreamHandler()
    #fh = logging.FileHandler(args.logging_file_path)

    #formatter = logging.Formatter('%(asctime)s - %(message)s')
    #ch.setFormatter(formatter)
    #fh.setFormatter(formatter)
    #log.addHandler(fh)
    #log.addHandler(ch)
    #########################################################
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Use CUDA
    use_cuda = torch.cuda.is_available()
    args.use_cuda = use_cuda

    # Random seed
    random.seed(0)
    torch.manual_seed(0)
    if use_cuda:
        torch.cuda.manual_seed_all(0)
        device = 'cuda'
        cudnn.benchmark = True
    else:
        device = 'cpu'

    if args.evaluate == 1:
        device = 'cuda:0'

    if args.sp:
        model = mbnet(default=args.sp_cfg)
    else:
        #model = mobilenetv1(amc=args.amc, m=args.m)
        model = mbnet(amc=args.amc, m=args.m)
        print(model.cfg)

    cfg = model.cfg

    total_params = print_model_param_nums(model.cpu())
    total_flops = print_model_param_flops(model.cpu(),
                                          224,
                                          multiply_adds=False)
    print(total_params, total_flops)

    if not args.distributed:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    if args.label_smoothing:
        criterion = CrossEntropyLabelSmooth(num_classes).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    #### all parameters ####
    no_wd_params, wd_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if ".bn" in name or '.bias' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)
    no_wd_params = nn.ParameterList(no_wd_params)
    wd_params = nn.ParameterList(wd_params)

    optimizer = torch.optim.SGD([
        {
            'params': no_wd_params,
            'weight_decay': 0.
        },
        {
            'params': wd_params,
            'weight_decay': args.weight_decay
        },
    ],
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_save_path):
            print("=> loading checkpoint '{}'".format(args.model_save_path))
            checkpoint = torch.load(args.model_save_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_save_path, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.model_save_path))

    # Data loading code
    train_loader, val_loader = get_data_loader(args.data,
                                               train_batch_size=batch_size,
                                               test_batch_size=32,
                                               workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # train_sampler is assumed to come from the data pipeline; it is
            # not defined in this snippet.
            train_sampler.set_epoch(epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'cfg': cfg,
                'sr': args.sr,
                'amc': args.amc,
                's': args.s,
                'args': args,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, args.model_save_path)
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        #if idx0.size == 0: continue
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1, ))
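        # Function-preserving split (Net2Net-style): the columns indexed by
        # idx0 are halved and then duplicated by the concatenation below, so
        # the layer's output is unchanged after the upstream units split.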
        fc_weight = m0.weight.data.clone()
        fc_weight[:, idx0.tolist()] /= 2.
        fc_bias = m0.bias.data.clone()
        m1.weight.data = torch.cat((fc_weight, fc_weight[:, idx0.tolist()]), 1)
        m1.bias.data = fc_bias

log.info(model.cfg)
log.info(newmodel.cfg)

print('acc after splitting')
test(newmodel)

#if not args.debug:
torch.save(
    {
        'cfg': newmodel.cfg,
        'split_index': args.split_index,
        'grow': args.grow,
        'min_eig_vals': min_eig_vals,
        'state_dict': newmodel.state_dict(),
        'args': args,
    }, os.path.join(args.save, model_save_path))

print(os.path.join(args.save, model_save_path))
new_num_parameters = print_model_param_nums(newmodel.cpu())
new_num_flops = print_model_param_flops(newmodel.cpu(), 32)
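
# Hedged summary, assuming the helpers return raw counts as elsewhere in
# this codebase:
print('  + Number of params: %.3fM' % (new_num_parameters / 1e6))
print('  + Number of FLOPs: %.3fG' % (new_num_flops / 1e9))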
else:
    criterion = nn.CrossEntropyLoss().cuda()

# Data loading code
train_loader, val_loader = get_data_loader(args.data,
                                           train_batch_size=args.batch_size,
                                           test_batch_size=16,
                                           workers=args.workers)


## loading pretrained model ##
assert args.load
assert os.path.isfile(args.load)
print("=> loading checkpoint '{}'".format(args.load))
checkpoint = torch.load(args.load)

model = mbnet(cfg=checkpoint['cfg'])
total_params = print_model_param_nums(model)
total_flops = print_model_param_flops(model, 224, multiply_adds=False) 
print(total_params, total_flops)

if args.use_cuda: 
    model.cuda()

selected_model_keys = [
    k for k in model.state_dict().keys()
    if not (k.endswith('.y') or k.endswith('.v')
            or k.startswith('net_params') or k.startswith('y_params')
            or k.startswith('v_params'))
]
saved_model_keys = list(checkpoint['state_dict'].keys())
from collections import OrderedDict
new_state_dict = OrderedDict()
if len(selected_model_keys) == len(saved_model_keys):

    for k0, k1 in zip(selected_model_keys, saved_model_keys):
        new_state_dict[k0] = checkpoint['state_dict'][k1]
    # NOTE: new_state_dict is built but never applied in this snippet; a
    # model.load_state_dict(new_state_dict) call is presumably intended.
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        #checkpoint = torch.load(args.model)
        saved_state_dict = torch.load(args.model)
        model.load_state_dict(saved_state_dict)

        #args.start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        #model.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no model found at '{}'".format(args.model))

# show the original model's parameters and FLOPs
print('------------------------------')
print('Original model: ')
compute_flops.print_model_param_nums(model)

#print(model)
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size
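
# Hedged next step: sort the gathered scales and take a global magnitude
# threshold for channel pruning (`args.percent` is a hypothetical
# pruning-ratio flag not defined in this snippet).
sorted_bn, _ = torch.sort(bn)
threshold = sorted_bn[int(total * args.percent)]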
Example #9
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if args.resume:
        if not os.path.isdir(args.resume):
            os.makedirs(args.resume)
    log = open(os.path.join(args.save_path, '{}.txt'.format(args.description)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("use cuda: {}".format(args.use_cuda), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("description: {}".format(args.description), log)

    # Init data loader
    if args.dataset == 'cifar10':
        train_loader = dataset.cifar10DataLoader(True, args.batch_size, True, args.workers)
        test_loader = dataset.cifar10DataLoader(False, args.batch_size, False, args.workers)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_loader = dataset.cifar100DataLoader(True, args.batch_size, True, args.workers)
        test_loader = dataset.cifar100DataLoader(False, args.batch_size, False, args.workers)
        num_classes = 100
    elif args.dataset == 'imagenet':
        assert False, 'ImageNet support is not implemented'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    # Init model
    if args.arch == 'cifarvgg16':
        net = models.vgg16_cifar(True, num_classes)
    elif args.arch == 'resnet32':
        net = models.resnet32(num_classes)
    elif args.arch == 'resnet56':
        net = models.resnet56(num_classes)
    elif args.arch == 'resnet110':
        net = models.resnet110(num_classes)
    else:
        assert False, 'Unsupported architecture: {}'.format(args.arch)


    print_log("=> network:\n {}".format(net),log)
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        resume_ckpt = os.path.join(args.resume, 'checkpoint.pth.tar')
        if os.path.isfile(resume_ckpt):
            print_log("=> loading checkpoint '{}'".format(resume_ckpt), log)
            checkpoint = torch.load(resume_ckpt)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)

            if args.evaluate:
                time1=time.time()
                validate(test_loader,net,criterion,args.use_cuda,log)
                time2=time.time()
                print('validate function took %0.3f ms' % ((time2 - time1) * 1000.0))
                return
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> not use any checkpoint for {} model".format(args.description), log)

    if args.original_train:
        original_train.args.arch=args.arch
        original_train.args.dataset=args.dataset
        original_train.main()
        return

    comp_rate = args.rate
    m = mask.Mask(net, args.use_cuda)
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, args.use_cuda, log)
    print(" accu before is: %.3f %%" % val_acc_1)

    m.model = net
    print('before pruning')
    m.init_mask(comp_rate, args.last_index)
    m.do_mask()
    print('after pruning')
    m.print_weights_zero()
    net = m.model  # update net

    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)
    print(" accu after is: %.3f %%" % val_acc_2)

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, optimizer,
                                                     epoch, args.gammas, args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
                time_string(), epoch, args.epochs, need_time, current_learning_rate)
            + ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)), log)
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch,
                                     args.use_cuda, log)
        validate(test_loader, net, criterion, args.use_cuda, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            print('before pruning')
            m.print_weights_zero()
            m.init_mask(comp_rate, args.last_index)
            m.do_mask()
            print('after pruning')
            m.print_weights_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)
        if args.resume:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.resume, 'checkpoint.pth.tar')
        print('save ckpt done')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    torch.save(net, args.model_save)
    # torch.save(net, args.save_path)
    flops.print_model_param_nums(net)
    flops.count_model_param_flops(net, 32, False)
    log.close()
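
# Hedged entry point, assuming the script is meant to be run directly:
if __name__ == '__main__':
    main()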