Example #1
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data

    if 'cifar' in args.dataset:
        print('==> Preparing cifar dataset %s' % args.dataset)
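        # Training-time augmentation: pad-and-crop plus horizontal flip; the normalization
        # constants are the commonly used per-channel CIFAR-10 mean/std.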
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        if args.dataset == 'cifar10':
            dataloader = datasets.CIFAR10
            num_classes = 10
        else:
            dataloader = datasets.CIFAR100
            num_classes = 100


        trainset = dataloader(root=os.path.join(homeDir,args.dataset), train=True, download=True, transform=transform_train)

        train_sampler=None
        toShuffle = True
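        # --subsample < 1 restricts training to the first fraction of the training set,
        # visited in random order through a SubsetRandomSampler.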
        if args.subsample < 1:
            toShuffle=False
            n = int(float(len(trainset)) * args.subsample)
            assert n > 0,'must sample a positive number of training examples.'
            train_sampler = data.sampler.SubsetRandomSampler(range(n))
            print('==>SAMPLING FIRST',n,'TRAINING IMAGES')


        # --class_subset restricts training to an underscore-separated list of class ids (e.g. '0_3_7').
        if args.class_subset != '_':
            print('*' + args.class_subset + '*')
            toShuffle = False
            args.class_subset = [int(i) for i in args.class_subset.split('_')]
            indices = [i for i, p in enumerate(trainset.train_labels) if p in args.class_subset]
            train_sampler = data.sampler.SubsetRandomSampler(indices)
        trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=toShuffle, num_workers=args.workers,sampler=train_sampler)

        testset = dataloader(root=os.path.join(homeDir ,args.dataset), train=False, download=False, transform=transform_test)
        test_sampler = None
        if args.test_subsample < 1:
            n = int(float(len(testset)) * args.test_subsample)
            assert n > 0, 'must sample a positive number of test examples.'
            test_sampler = data.sampler.SubsetRandomSampler(range(n))
            print('==> SAMPLING FIRST', n, 'TESTING IMAGES')

        if type(args.class_subset) is list:
            toShuffle = False
            indices = [i for i, p in enumerate(testset.test_labels) if p in args.class_subset]
            test_sampler = data.sampler.SubsetRandomSampler(indices)

        testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers,sampler=test_sampler)
    else:
        train_dir = '/home/amir/data/imagenet12/train'
        test_dir = '/home/amir/data/imagenet12/val'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([transforms.RandomCrop(64), transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize])
        transform_test = transforms.Compose([transforms.CenterCrop((64, 64)), transforms.ToTensor(), normalize])
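        # The 64x64 crops suggest this branch is meant for a downsampled ImageNet variant
        # rather than the full-resolution dataset.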
            
        trainloader = DataLoader(dataset=datasets.ImageFolder(root = train_dir, transform=transform_train),
                    batch_size=args.train_batch, shuffle=True,num_workers=1)
        
        testloader = DataLoader(dataset=datasets.ImageFolder(root = test_dir, transform=transform_test),
                    batch_size=64, shuffle=True)        

    # Model   
    print("==> creating model '{}'".format(args.arch))
    if 'cifar' in args.dataset:
        if args.arch.startswith('resnext'):
            model = models.__dict__[args.arch](
                        cardinality=args.cardinality,
                        num_classes=num_classes,
                        depth=args.depth,
                        widen_factor=args.widen_factor,
                        dropRate=args.drop,
                    )
        elif args.arch.startswith('densenet'):
            if 'partial' in args.arch:
                model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                    part=args.part,
                    zero_fixed_part=args.zero_fixed_part,
                    do_init=True,
                    split_dim=args.dim_slice
                )
            else:
                model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                    lateral_inhibition=args.lateral_inhibition
                )
        elif args.arch.startswith('wrn'):
            if 'partial' in args.arch:
                print('==> initializing partial learning with p=',args.part)

                print('classes',num_classes,'depth',args.depth,'widen',args.widen_factor,'drop',args.drop,'part:',args.part)

                model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop, part=args.part, zero_fixed_part=args.zero_fixed_part
                )
            else:
                model = models.__dict__[args.arch](
                            num_classes=num_classes,
                            depth=args.depth,
                            widen_factor=args.widen_factor,
                            dropRate=args.drop,lateral_inhibition=args.lateral_inhibition
                        )
        elif args.arch.endswith('resnet'):
            model = models.__dict__[args.arch](
                        num_classes=num_classes,
                        depth=args.depth,
                    )
        else:
            if 'partial' in args.arch:
                print('==> creating partial model', args.arch)
                model = models.__dict__[args.arch](num_classes=num_classes, part=args.part, zero_fixed_part=args.zero_fixed_part)
            else:
                print('==> creating default model', args.arch)
                model = models.__dict__[args.arch](num_classes=num_classes)
    else: # must be imagenet
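        # The ImageNet branch hard-codes a WRN-28-4 configuration (depth=28, widen_factor=4, 1000 classes).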
        if 'partial' not in args.arch:
            model = imagenet_models.__dict__[args.arch](depth=28, widen_factor=4, num_classes=1000)
        else:
            model = imagenet_models.__dict__[args.arch](depth=28, widen_factor=4, num_classes=1000,part=args.part,zero_fixed_part=args.zero_fixed_part)



    # hack hack
    #print('==============================', arch,'===============')
    #if 'squeeze' in args.arch:
    #    model.classifier = nn.Sequential(nn.Dropout(.5), nn.Conv2d(512, 10, kernel_size=(1, 1), stride=(1, 1)),
    #                                     nn.ReLU())
    #    print(model)
    from copy import deepcopy
    model = torch.nn.DataParallel(model).cuda()
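    # If a previously trained model is supplied, its weights are loaded; for 'partial'
    # architectures, split_model presumably redistributes them between the fixed and learned parts.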
    if args.load_fixed_path != '':
        # load the (presumably) full model dict.
        print('ensembling - loading old dict')
        fixed_model_dict = torch.load(args.load_fixed_path)['state_dict']

        #if 'partial'

        if 'partial' in args.arch:
            print('transferring to new and splitting')
            split_model(fixed_model_dict,model)
        else: # just load the dictionary as is and continue from this point.
            model.load_state_dict(fixed_model_dict)
    if False:  # disabled: an older code path for loading the fixed part of a classifier for ensembling
        if args.load_fixed_path != '':
            print('ENSEMBLING')
            # load the fixed part of this classifier for ensembling

            fixed_model_dict = torch.load(args.load_fixed_path)['state_dict']
            model_dict = model.state_dict()

            if args.part == -1:
                print('-------------BABU-----------------')
                model.load_state_dict(fixed_model_dict)

            elif args.only_layer != 'none': # just copy everything; the chosen layer will be reinitialized later.
                for a,b in model_dict.items():
                    if args.only_layer not in a:
                        model_dict[a] = deepcopy(fixed_model_dict[a])
                model.load_state_dict(model_dict)
            else:
                for a, b in model_dict.items():
                    # transfer all fixed values from loaded dictionary.
                    if args.part < .5: # re-train learned part.
                        if 'fixed' in a:
                            model_dict[a] = deepcopy(fixed_model_dict[a])
                    else: # re-train what was at first the random part :-)
                        if 'learn' in a:
                            model_dict[a] = deepcopy(fixed_model_dict[a])
                model.load_state_dict(model_dict)

                if args.part >= .5: # switch training between fixed / learned parts.
                    print('HAHA, SWITCHING FIXED AND LEARNING')
                    for a,b in model.module.named_parameters():
                        if 'learn' in a:
                            b.requires_grad = False
                        else:
                            b.requires_grad = True
                        # otherwise, keep it as it is.
    
    assert not (args.retrain_layer != 'none' and args.only_layer != 'none'),'retrain-layer and only-layer options are mutually exclusive'
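    # --retrain_layer snapshots the freshly initialized weights so a single layer can later be
    # re-initialized from them and retrained; --only_layer presumably trains just the named layer.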
    
    if args.retrain_layer != 'none':        
        initial_dict = deepcopy(model.state_dict())
        
    
    cudnn.benchmark = True


    criterion = nn.CrossEntropyLoss()
    opt_ = args.optimizer.lower()
    if args.only_layer != 'none':
        model = only_layer(model,args.only_layer)


    # Apply --learn_bn: freeze or unfreeze the parameters of every batch-norm module.
    for m1, m2 in model.named_modules():
        if 'bn' in m1:
            for p in m2.parameters():
                p.requires_grad = args.learn_bn
    if args.learn_inhibition:
        for p in model.module.parameters():
            p.requires_grad=True
    
    # trainableParams presumably returns only the parameters with requires_grad=True.
    params = trainableParams(model)
    print('    Trainable params: %.2fM' % (sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000.0))
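    # Build the optimizer over the trainable parameters only; 'yf' selects YFOptimizer
    # (the YellowFin optimizer).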
    if opt_ == 'sgd':
        print('optimizer.... - sgd')
        optimizer = optim.SGD(params , lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif opt_ == 'adam':
        optimizer = optim.Adam(params)
    elif opt_ == 'yf':
        print('USING YF OPTIMIZER')
        optimizer = YFOptimizer(
            params, lr=args.lr, mu=0.0, weight_decay=args.weight_decay, clip_thresh=2.0, curv_win_width=20)
        optimizer._sparsity_debias = False
    else:
        raise Exception('unsupported optimizer type',opt_)

    nParamsPath = os.path.join(args.checkpoint, 'n_params.txt')
    with open(nParamsPath, 'w') as f:
        s1 = 'active_params {} \n'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))
        f.write(s1)
        s2 = 'total_params {} \n'.format(sum(p.numel() for p in model.parameters()))
        f.write(s2)
    if args.print_params_and_exit:
        exit()

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        #args.checkpoint = os.path.dirname(args.checkpoint)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        #start_epoch = checkpoint['epoch']
        start_epoch = args.start_epoch
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        
                
        if args.retrain_layer != 'none':
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=False)
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
            # Materialize the parameter list so printing its length cannot exhaust a generator
            # before it is handed to the optimizer.
            params = list(trainableParams(model))
            print('number of trainable params:', len(params))
            model = reinit_model_layer(model, args.retrain_layer, initial_dict)
            params = list(trainableParams(model))
            print('number of trainable params:', len(params))
            optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        else:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=False) # Was True
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
            
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])


    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val

    #scheduler = CosineAnnealingLR( optimizer, T_max=args.epochs)#  eta_min = 1e-9, last_epoch=args.epochs)
    if args.half:
        model = model.half()

    for epoch in range(start_epoch, args.epochs):
        if args.sgdr > 0:
            # The cosine scheduler above is commented out, so SGDR is not supported in this version;
            # calling scheduler.step() here would raise a NameError.
            raise Exception('currently not supporting sgdr')
        else:
            adjust_learning_rate(optimizer, epoch)
        if type(optimizer) is YFOptimizer:
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.get_lr_factor()))
        else:
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)
        #if req_perf_after_10_epochs > -1

        # append to the log ('state' is presumably a module-level dict whose 'lr' entry adjust_learning_rate keeps current)
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        if epoch % 10 == 0: # additionally keep a numbered snapshot every 10 epochs
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint,filename='checkpoint.pth.tar_'+str(epoch).zfill(4))

        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)