Exemple #1
0
def main():
    """Run the prune-and-finetune experiment configured by the module-level ``args``.

    Creates ``args.save_path`` (and ``args.resume`` when set), opens the text
    log ``<save_path>/<description>.txt`` and runs the workflow. The log file
    is closed on every exit path, including the evaluate-only and
    ``original_train`` early exits and any exception.
    """
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if args.resume and not os.path.isdir(args.resume):
        os.makedirs(args.resume)
    log = open(os.path.join(args.save_path, '{}.txt'.format(args.description)), 'w')
    try:
        _prune_and_finetune(log)
    finally:
        # The workflow has several early exits; always release the log handle.
        log.close()


def _build_loaders():
    """Return ``(train_loader, test_loader, num_classes)`` for ``args.dataset``.

    Raises:
        AssertionError: for 'imagenet' (not implemented) or an unknown dataset.
    """
    if args.dataset == 'cifar10':
        make_loader, num_classes = dataset.cifar10DataLoader, 10
    elif args.dataset == 'cifar100':
        make_loader, num_classes = dataset.cifar100DataLoader, 100
    elif args.dataset == 'imagenet':
        # Raised explicitly: a bare ``assert`` is stripped under ``python -O``.
        raise AssertionError('Do not finish imagenet code')
    else:
        raise AssertionError('Do not support dataset : {}'.format(args.dataset))
    train_loader = make_loader(True, args.batch_size, True, args.workers)
    test_loader = make_loader(False, args.batch_size, False, args.workers)
    return train_loader, test_loader, num_classes


def _build_network(num_classes):
    """Instantiate the architecture named by ``args.arch``.

    Raises:
        AssertionError: when ``args.arch`` is not one of the supported names.
    """
    if args.arch == 'cifarvgg16':
        return models.vgg16_cifar(True, num_classes)
    if args.arch == 'resnet32':
        return models.resnet32(num_classes)
    if args.arch == 'resnet56':
        return models.resnet56(num_classes)
    if args.arch == 'resnet110':
        return models.resnet110(num_classes)
    raise AssertionError('Not finished')


def _make_optimizer(net, state):
    """Build the Nesterov SGD optimizer over ``net``'s parameters from ``state``."""
    return torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                           weight_decay=state['decay'], nesterov=True)


def _prune(m, net, comp_rate, show_before):
    """Rebuild the mask for ``net``, apply it, and return the masked network.

    ``show_before`` additionally reports the zero-weight statistics before the
    mask is applied. The returned network is moved to CUDA when
    ``args.use_cuda`` is set.
    """
    m.model = net
    print('before pruning')
    if show_before:
        m.print_weights_zero()
    m.init_mask(comp_rate, args.last_index)
    m.do_mask()
    print('after pruning')
    m.print_weights_zero()
    net = m.model  # update net
    if args.use_cuda:
        net = net.cuda()
    return net


def _prune_and_finetune(log):
    """Log the configuration, build data/model, optionally resume, then prune and train.

    Returns early (without training) when ``args.evaluate`` is set and a
    checkpoint was loaded, or after delegating to ``original_train.main()``
    when ``args.original_train`` is set.
    """
    print_log('save path : {}'.format(args.save_path), log)
    state = dict(args._get_kwargs())
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("use cuda: {}".format(args.use_cuda), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("description: {}".format(args.description), log)

    # Init data loader
    train_loader, test_loader, num_classes = _build_loaders()

    # Init model
    net = _build_network(num_classes)
    print_log("=> network:\n {}".format(net), log)
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = _make_optimizer(net, state)
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        # Join instead of concatenating so the file is looked up inside the
        # directory created by main(), with or without a trailing separator.
        # NOTE(review): presumably matches where save_checkpoint writes - confirm.
        ckpt_file = os.path.join(args.resume, 'checkpoint.pth.tar')
        if os.path.isfile(ckpt_file):
            print_log("=> loading checkpoint '{}'".format(ckpt_file), log)
            # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
            checkpoint = torch.load(ckpt_file)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            saved = checkpoint['state_dict']
            if args.use_state_dict:
                # The training loop below stores the whole module under
                # 'state_dict'; unwrap it so load_state_dict gets a mapping.
                if isinstance(saved, torch.nn.Module):
                    saved = saved.state_dict()
                net.load_state_dict(saved)
            else:
                net = saved
                # The optimizer built above still points at the discarded
                # network's parameters; rebind it to the loaded one.
                optimizer = _make_optimizer(net, state)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)

            if args.evaluate:
                time1 = time.time()
                validate(test_loader, net, criterion, args.use_cuda, log)
                time2 = time.time()
                print('validate function took %0.3f ms' % ((time2 - time1) * 1000.0))
                return
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> not use any checkpoint for {} model".format(args.description), log)

    if args.original_train:
        original_train.args.arch = args.arch
        original_train.args.dataset = args.dataset
        original_train.main()
        return

    comp_rate = args.rate
    m = mask.Mask(net, args.use_cuda)
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, args.use_cuda, log)
    print(" accu before is: %.3f %%" % val_acc_1)

    # Initial pruning pass before any fine-tuning.
    net = _prune(m, net, comp_rate, show_before=False)
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)
    print(" accu after is: %.3f %%" % val_acc_2)

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, optimizer, epoch, args.gammas, args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate)
            + ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)), log)
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, args.use_cuda, log)
        # Validation right after training, before any re-pruning this epoch.
        validate(test_loader, net, criterion, args.use_cuda, log)
        # Re-prune every ``epoch_prune`` epochs and always on the final epoch.
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            net = _prune(m, net, comp_rate, show_before=True)

        # This second validation is the one recorded for the epoch.
        val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)
        if args.resume:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.resume, 'checkpoint.pth.tar')
        print('save ckpt done')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    torch.save(net, args.model_save)
    flops.print_model_param_nums(net)
    flops.count_model_param_flops(net, 32, False)
Exemple #2
0
print("=> creating model '{}'".format(args.arch))

# Each supported architecture is built by the constructor that shares its name:
# ImageNet variants come from `resnet_imagenet`, every other dataset uses the
# 10-class constructors in `models`. Any other arch leaves `net` unassigned.
if args.dataset == 'imagenet':
    if args.arch in ('resnet101', 'resnet50', 'resnet34', 'resnet18'):
        net = getattr(resnet_imagenet, args.arch)()
else:
    if args.arch in ('resnet110', 'resnet56', 'resnet32', 'resnet20'):
        net = getattr(models, args.arch)(num_classes=10)

# For ImageNet runs, read the checkpoint file for the chosen architecture from
# `args.pretrain_path` into `state_dict`.
# NOTE(review): only resnet101 and resnet50 are handled in the visible code; for
# any other arch `state_dict` is not assigned here - confirm the remaining
# branches exist further down.
if args.dataset == 'imagenet':

    if args.arch == 'resnet101':
        # File names presumably match the upstream pretrained releases - TODO confirm.
        state_dict = torch.load(
            os.path.join(args.pretrain_path, 'resnet101-5d3b4d8f.pth'))
    elif args.arch == 'resnet50':
        state_dict = torch.load(
            os.path.join(args.pretrain_path, 'resnet50-19c8e357.pth'))
Exemple #3
0
                    type=float,
                    default=0.5,
                    help='beta in adversarial training')
# --regu selects the regularization scheme (default 'no').
# Adjacent string literals are concatenated verbatim, so every fragment of the
# help text carries its own separator to keep the rendered help readable.
parser.add_argument('--regu',
                    type=str,
                    default='no',
                    help='type of regularization. Possible values are: '
                    'no: no regularization; '
                    'random-svd: employ random-svd in regularization')

if __name__ == "__main__":
    args = parser.parse_args()
    # Build the network: 'cifar10' gets 10 classes, any other dataset value 100.
    n_classes = 10 if args.dataset == 'cifar10' else 100
    # Reject unknown model names up front, then dispatch on the known ones.
    if args.model not in ('resnet', 'wideresnet', 'resnext'):
        raise Exception('Invalid model name')
    if args.model == 'resnet':
        net = resnet110(num_classes=n_classes)
    elif args.model == 'wideresnet':
        net = WideResNet(
            depth=28, widen_factor=10, dropRate=0.3, num_classes=n_classes)
    else:
        net = CifarResNeXt(
            cardinality=8, depth=29, base_width=64, widen_factor=4,
            nlabels=n_classes)
    # create optimizer
    optimizer = torch.optim.SGD(net.parameters(),