def test():
    # define default variables
    args = get_args()  # argument parsing is factored out into get_args()
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
    state = {k: v for k, v in args._get_kwargs()}
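    # 'state' collects the parsed arguments into a plain dict; the test metrics are added to it and printed at the end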

    # prepare test data parts
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])
    if args.dataset == 'cifar10':
        test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
        nlabels = 10
    else:
        test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
        nlabels = 100

    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_bs, shuffle=False,
                                              num_workers=args.prefetch, pin_memory=True)

    # initialize model and load from checkpoint
    net = CifarResNeXt(args.cardinality, args.depth, nlabels, args.widen_factor)
    loaded_state_dict = torch.load(args.load, map_location='cpu')
    temp = {}
    for key, val in loaded_state_dict.items():
        # strip the 'module.' prefix that nn.DataParallel adds to every key
        temp[key[7:]] = val
    loaded_state_dict = temp
    net.load_state_dict(loaded_state_dict)

    # parallelize the model across the available GPUs
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    if args.ngpu > 0:
        net.cuda()
   
    # switch the network to evaluation mode
    net.eval()

    # calculation part
    loss_avg = 0.0
    correct = 0.0
    for batch_idx, (data, target) in enumerate(test_loader):
        # move the batch to the GPU when one is used (the deprecated Variable wrapper is dropped)
        if args.ngpu > 0:
            data, target = data.cuda(), target.cuda()

        # forward
        output = net(data).detach()
        print(output)
        loss = F.cross_entropy(output, target)

        # accuracy
        pred = output.data.max(1)[1]
        correct += pred.eq(target.data).sum().item()
        print("total: {0}, correct: {1}".format((batch_idx + 1) * args.test_bs, correct))

        # test loss average
        loss_avg += loss.item()

    state['test_loss'] = loss_avg / len(test_loader)
    state['test_accuracy'] = correct / len(test_loader.dataset)

    # finally print state dictionary
    print(state)
Example #2
def main():
    args = parse_args()
    print(colored("Setting default tensor type to cuda.FloatTensor", "cyan"))
    torch.multiprocessing.set_start_method('spawn')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    #device = torch.device('cuda:{}'.format(args.local_rank))

    model = eval('models.' + config.MODEL.NAME +
                 '.get_contrastive_net')(config).cuda()
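    # the encoder class is resolved dynamically as models.<MODEL.NAME>.get_contrastive_net(config)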

    dump_input = torch.rand(config.TRAIN.BATCH_SIZE_PER_GPU, 3,
                            config.MODEL.IMAGE_SIZE[1],
                            config.MODEL.IMAGE_SIZE[0]).cuda()
    logger.info(get_model_summary(model, dump_input))

    if config.TRAIN.MODEL_FILE:
        model.load_state_dict(torch.load(config.TRAIN.MODEL_FILE))
        logger.info(
            colored('=> loading model from {}'.format(config.TRAIN.MODEL_FILE),
                    'red'))

    #if args.local_rank == 0:
    # copy model file
    this_dir = os.path.dirname(__file__)
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)
    """
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
        )
    """

    torch.cuda.empty_cache()
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()
    print("Finished constructing encoder!")

    # define loss function (criterion) and optimizer
    info_nce = InfoNCE(config.CONTRASTIVE.TAU, config.CONTRASTIVE.NORMALIZE,
                       config.CONTRASTIVE.NUM_SAMPLES)
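    # InfoNCE pulls each sample's NUM_SAMPLES augmented views together and contrasts them against the rest of the batch; TAU is the softmax temperature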
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = get_optimizer(config, model)
    lr_scheduler = None

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])

            # override the stored weight decay with the config value before restoring the optimizer state
            checkpoint['optimizer']['param_groups'][0][
                'weight_decay'] = config.TRAIN.WD
            optimizer.load_state_dict(checkpoint['optimizer'])

            if 'lr_scheduler' in checkpoint:
                lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer,
                    1e5,
                    last_epoch=checkpoint['lr_scheduler']['last_epoch'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
            best_model = True

    # Data loading code
    dataset_name = config.DATASET.DATASET

    if dataset_name == 'imagenet':
        # ImageNet path; not fully supported yet
        traindir = os.path.join(config.DATASET.ROOT + '/',
                                config.DATASET.TRAIN_SET)
        valdir = os.path.join(config.DATASET.ROOT + '/',
                              config.DATASET.TEST_SET)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = MultiSample(aug_transform(
            config.MODEL.IMAGE_SIZE[0],
            base_transform((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            config.DATASET.AUGMENTATIONS),
                                      n=config.CONTRASTIVE.NUM_SAMPLES)
        transform_valid = transforms.Compose([
            transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
            transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
            base_transform((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        train_dataset = datasets.ImageFolder(traindir, transform_train)
        valid_dataset = datasets.ImageFolder(valdir, transform_valid)
    else:
        # only the CIFAR variants run right now
        #assert dataset_name == "cifar10", "Only CIFAR-10 is supported at this phase"
        #classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')  # For reference

        # the training transform returns NUM_SAMPLES augmented views per image to build the contrastive batch
        transform_train = MultiSample(aug_transform(
            32, base_transform(), config['DATASET']['AUGMENTATIONS']),
                                      n=config['CONTRASTIVE']['NUM_SAMPLES'])
        transform_valid = base_transform()
        if dataset_name == 'cifar10':
            train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
            valid_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}',
                                             train=False,
                                             download=True,
                                             transform=transform_valid)
        elif dataset_name == 'cifar100':
            train_dataset = datasets.CIFAR100(root=f'{config.DATASET.ROOT}',
                                              train=True,
                                              download=True,
                                              transform=transform_train)
            valid_dataset = datasets.CIFAR100(root=f'{config.DATASET.ROOT}',
                                              train=False,
                                              download=True,
                                              transform=transform_valid)
        else:
            raise ValueError('Unsupported dataset: {}'.format(dataset_name))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    # Learning rate scheduler
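    # the cosine schedule's T_max spans every optimizer step (batches x epochs), so it is presumably stepped per batch inside train(); the step schedules are stepped once per epoch in the loop below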
    if lr_scheduler is None:
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                len(train_loader) * config.TRAIN.END_EPOCH,
                eta_min=1e-6)
        elif isinstance(config.TRAIN.LR_STEP, list):
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
                last_epoch - 1)
        else:
            lr_scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
                last_epoch - 1)

    # Training code
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        topk = (1, 5) if dataset_name == 'imagenet' else (1, )
        if config.TRAIN.LR_SCHEDULER == 'step':
            lr_scheduler.step()

        # train for one epoch
        train(config, train_loader, model, info_nce, criterion, optimizer,
              lr_scheduler, epoch, final_output_dir, tb_log_dir)
        torch.cuda.empty_cache()

        # evaluate on validation set
        #perf_indicator = validate(config, valid_loader, model, criterion, lr_scheduler, epoch,
        #final_output_dir, tb_log_dir)
        torch.cuda.empty_cache()
        writer_dict['writer'].flush()
        """
        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False
        """

        best_model = True

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': config.MODEL.NAME,
                'state_dict': model.module.state_dict(),
                #'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            },
            best_model,
            final_output_dir,
            filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
                                           transform=train_transform)
            test_dataset = datasets.CIFAR10(args.data_dir, train=False, download=False,
                                          transform=test_transform)
        elif args.model == 'AlexNet':

            train_transform, test_transform = get_data_transform('cifar')

            if args.data_name == 'cifar10':
                model = AlexNetForCIFAR()
                train_dataset = datasets.CIFAR10(args.data_dir, train=True, download=False,
                                                 transform=train_transform)
                test_dataset = datasets.CIFAR10(args.data_dir, train=False, download=False,
                                                transform=test_transform)
            else:
                model = AlexNetForCIFAR(num_classes=100)
                train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=False,
                                                  transform=train_transform)
                test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=False,
                                                 transform=test_transform)
        elif args.model == 'ResNet18OnCifar10':
            model = ResNetOnCifar10.ResNet18()

            train_transform, test_transform = get_data_transform('cifar')
            train_dataset = datasets.CIFAR10(args.data_dir, train=True, download=True,
                                             transform=train_transform)
            test_dataset = datasets.CIFAR10(args.data_dir, train=False, download=True,
                                            transform=test_transform)
        elif args.model == 'ResNet34':
            model = models.resnet34(pretrained=False)

            train_transform = transforms.Compose([
                transforms.ToTensor(),
# Per-channel normalization statistics (CIFAR-10 values, applied here to the CIFAR-100 data below).
normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447],
                                 std=[0.247, 0.243, 0.262])

print('==> Preparing data..')
transforms_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), normalize
])

transforms_test = transforms.Compose([transforms.ToTensor(), normalize])

train_dataset = datasets.CIFAR100('../data',
                                  train=True,
                                  download=True,
                                  transform=transforms_train)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_dataset = datasets.CIFAR100('../data',
                                 train=False,
                                 download=True,
                                 transform=transforms_test)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

num_class = 100
image_channels = 3
Example #5
def main():
    global args, best_prec1

    if args.seed is None:
        args.seed = int(time.time())
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    R = 32
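    # R is the input resolution used for the FLOPs/params measurement below (32 for CIFAR, 224 for ImageNet)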
    if args.data == 'cifar10':
        args.num_classes = 10
    elif args.data == 'cifar100':
        args.num_classes = 100
    else:
        args.num_classes = 1000
        R = 224

    if 'densenet' in args.model:
        args.stages = list(map(int, args.stages.split('-')))
        args.growth = list(map(int, args.growth.split('-')))

    ### Calculate FLOPs & Param
    model = getattr(models, args.model)(args)
    n_flops, n_params = measure_model(model, R, R)
    print('FLOPs: %.2fM, Params: %.2fM' % (n_flops / 1e6, n_params / 1e6))

    os.makedirs(args.savedir, exist_ok=True)
    log_file = os.path.join(args.savedir, "%s_%d_%d.txt" % \
        (args.model, int(n_params), int(n_flops)))
    del model
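    # the copy above existed only to measure cost; a fresh model is built for training below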

    ### Create model
    model = getattr(models, args.model)(args)
    model = torch.nn.DataParallel(model).cuda()

    ### Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    cudnn.benchmark = True

    ### Data loading
    if args.data == "cifar10":
        normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
                                         std=[0.2471, 0.2435, 0.2616])
        train_set = datasets.CIFAR10(args.datadir,
                                     train=True,
                                     download=True,
                                     transform=transforms.Compose([
                                         transforms.RandomCrop(32, padding=4),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]))
        val_set = datasets.CIFAR10(args.datadir,
                                   train=False,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    elif args.data == "cifar100":
        normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                         std=[0.2675, 0.2565, 0.2761])
        train_set = datasets.CIFAR100(args.datadir,
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.RandomCrop(32, padding=4),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          normalize,
                                      ]))
        val_set = datasets.CIFAR100(args.datadir,
                                    train=False,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
    else:  #imagenet
        traindir = os.path.join(args.datadir, 'train')
        valdir = os.path.join(args.datadir, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_set = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_set = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    ### Optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume or (args.evaluate is not None):
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            model.load_state_dict(checkpoint['state_dict'])
            try:
                args.start_epoch = checkpoint['epoch'] + 1
                best_prec1 = checkpoint['best_prec1']
                optimizer.load_state_dict(checkpoint['optimizer'])
            except KeyError:
                pass

    ### Evaluate directly if required
    print(args)
    if args.evaluate is not None:
        validate(val_loader, model, criterion, args)
        return

    saveID = None
    for epoch in range(args.start_epoch, args.epochs):
        ### Train for one epoch
        tr_prec1, tr_prec5, loss, lr = \
            train(train_loader, model, criterion, optimizer, epoch, args)

        ### Evaluate on validation set
        val_prec1, val_prec5 = validate(val_loader, model, criterion, args)

        ### Remember best prec@1 and save checkpoint
        is_best = val_prec1 >= best_prec1
        best_prec1 = max(val_prec1, best_prec1)

        log = ("Epoch %03d/%03d: top1 %.4f | top5 %.4f" + \
              " | train-top1 %.4f | train-top5 %.4f | loss %.4f | lr %.5f | Time %s\n") \
              % (epoch, args.epochs, val_prec1, val_prec5, tr_prec1, \
              tr_prec5, loss, lr, time.strftime('%Y-%m-%d %H:%M:%S'))
        with open(log_file, 'a') as f:
            f.write(log)

        saveID = save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            epoch,
            args.savedir,
            is_best,
            saveID,
            keep_freq=args.save_freq)

    return
Example #6
def main():
    global args, best_prec1
    args = parser.parse_args()

    # create model
    
    if 'cifar' in args.arch:
        print "CIFAR Model Fix args.lastout As 8"
        args.lastout += 1
        
    
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = resnext_models[args.arch](pretrained=True, numlayers = args.numlayers,\
                                          expansion = args.xp, x = args.x, d = args.d, \
                                         upgroup = True if args.ug else False, downgroup = True if args.dg else False,\
                                         secord = True if args.secord else False, soadd = args.soadd, \
                                         att = True if args.att else False, lastout = args.lastout, dilpat = args.dp, \
                                         deform = args.df, fixx = args.fixx, sqex = args.sqex , ratt = args.ratt, \
                                         nocompete = args.labelnocompete)
        
    else:
        print("=> creating model '{}'".format(args.arch))
        model = resnext_models[args.arch](numlayers = args.numlayers, \
                                          expansion = args.xp, x = args.x , d = args.d, \
                                         upgroup = True if args.ug else False, downgroup = True if args.dg else False,\
                                         secord = True if args.secord else False, soadd = args.soadd, \
                                         att = True if args.att else False, lastout = args.lastout, dilpat = args.dp,
                                         deform = args.df, fixx = args.fixx , sqex = args.sqex , ratt = args.ratt ,\
                                         nocompete = args.labelnocompete)
        #print("args.df: {}".format(args.df))
    
    
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            
            #print type(checkpoint)
            
            model.load_state_dict(checkpoint['state_dict'])
            
            if args.finetune:
                args.start_epoch = 0
                print "start_epoch is ",args.start_epoch
                topfeature = int(args.x * args.d * 8 * args.xp)
                model.fc = nn.Linear(topfeature, args.nclass)
                
                
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            
            # For Fine-tuning
            
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.ds == "dir":
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

        train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(traindir, transforms.Compose([
                transforms.RandomSizedCrop(args.lastout*32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)
        
        if args.evaluate == 2:
            
            val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, transforms.Compose([
                    transforms.Scale((args.lastout+args.evalmodnum)*32),
                    transforms.CenterCrop(args.lastout*32),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])),
                batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        
        elif args.evaluate == 3:
            
            val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, transforms.Compose([
                    transforms.Scale((args.lastout+args.evalmodnum)*32),
                    transforms.RandomCrop((args.lastout+args.evalmodnum)*32),
                    transforms.RandomCrop(args.lastout*32),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])),
                batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
            
        else:
            
            val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, transforms.Compose([
                    transforms.Scale((args.lastout+1)*32),
                    transforms.CenterCrop(args.lastout*32),
                    transforms.ToTensor(),
                    normalize,
                ])),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)
        
    elif args.ds in ["CIFAR10","CIFAR100"]:
        normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            ])
        
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize
            ])
        
        if args.ds == "CIFAR10":
            
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=True, download=True,
                             transform=transform_train),
                             batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
        else:
            
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=True, download=True,
                             transform=transform_train),
                             batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
        
    else:
        print "Unrecognized Dataset. Halt."
        return 0
        
        
    # define loss function (criterion) and optimizer
    #criterion = nn.CrossEntropyLoss().cuda()
    if 'L1' in args.arch or args.L1 == 1:
        criterion = nn.L1Loss(size_average=True).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

        
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,nesterov=False if args.nes == 0 else True)
    #optimizer = torch.optim.Adam(model.parameters(), args.lr)
    
    if args.evaluate == 2 :
        NUM_MULTICROP = 2
        for i in range(0,NUM_MULTICROP):
            test_output(val_loader, model, 'Result_{0}_{1}_{2}'.format(args.evaluate, i, args.evalmodnum))
        return
    
    elif args.evaluate == 3 :
        NUM_MULTICROP = 8
        for i in range(0,NUM_MULTICROP):
            # re-create val_loader for each multi-crop pass
            val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, transforms.Compose([
                    transforms.Scale((args.lastout+args.evalmodnum)*32),
                    transforms.RandomCrop((args.lastout+args.evalmodnum)*32),
                    transforms.RandomCrop(args.lastout*32),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])),
                batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
            # Test
            test_output(val_loader, model, args.evaltardir+'Result_{0}_{1}_{2}'.format(args.evaluate, i, args.evalmodnum))
        return
    
    elif args.evaluate == 1:
        test_output(val_loader, model, 'Result_00')
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        for i in range(args.tl):
            train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
        print('Current best accuracy:', best_prec1)
    print('Global best accuracy:', best_prec1)
Example #7
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'ImageNet support is not implemented yet'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    comp_rate = args.rate
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    print(" accu before is: %.3f %%" % val_acc_1)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
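        # adjust_learning_rate presumably scales the base LR by args.gammas at the epochs listed in args.schedule and returns the current value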

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_1,
                                  val_acc_1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    log.close()
Example #8
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
elif args.dataset == 'cifar100':
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('./data.cifar100', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Pad(4),
                           transforms.RandomCrop(32),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
else:
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    def factory(
        self,
        pathname,
        name,
        subset='train',
        idenselect=[],
        download=False,
        transform=None,
    ):
        """Factory dataset
        """

        assert (self._checksubset(subset))
        pathname = os.path.expanduser(pathname)

        # datasets supported through torchvision

        if name == 'mnist':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.MNIST(pathname,
                                  train=btrain,
                                  transform=transform,
                                  download=download)
            data.labels = np.array(data.targets)
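            # expose the labels as a numpy array; newer torchvision stores them under .targets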

        elif name == 'fashion':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.FashionMNIST(pathname,
                                         train=btrain,
                                         transform=transform,
                                         download=download)
            data.labels = np.array(data.targets)

        elif name == 'emnist':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.EMNIST(pathname,
                                   split='byclass',
                                   train=btrain,
                                   transform=transform,
                                   download=download)
            data.labels = np.array(data.targets)

        elif name == 'cifar10':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.CIFAR10(pathname,
                                    train=btrain,
                                    transform=transform,
                                    download=download)
            data.labels = np.array(data.targets)

        elif name == 'cifar100':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.CIFAR100(pathname,
                                     train=btrain,
                                     transform=transform,
                                     download=download)
            data.labels = np.array(data.targets)

        elif name == 'stl10':
            split = 'train' if (subset == 'train') else 'test'
            pathname = create_folder(pathname, name)
            data = datasets.STL10(pathname,
                                  split=split,
                                  transform=transform,
                                  download=download)

        elif name == 'svhn':
            split = 'train' if (subset == 'train') else 'test'
            pathname = create_folder(pathname, name)
            data = datasets.SVHN(pathname,
                                 split=split,
                                 transform=transform,
                                 download=download)
            data.classes = np.unique(data.labels)

        # internet dataset

        elif name == 'cub2011':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = cub2011.CUB2011(pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)

        elif name == 'cars196':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = cars196.Cars196(pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)

        elif name == 'stanford_online_products':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = stanford_online_products.StanfordOnlineProducts(
                pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)
            data.btrain = btrain

        # kaggle dataset
        elif name == 'imaterialist':
            pathname = create_folder(pathname, name)
            data = imaterialist.IMaterialistDatset(pathname, subset, 'jpg')

        # facial expression recognition (FER) datasets

        elif name == 'ferp':
            pathname = create_folder(pathname, name)
            if subset == 'train': subfolder = ferp.train
            elif subset == 'val': subfolder = ferp.valid
            elif subset == 'test': subfolder = ferp.test
            else: assert (False)
            data = ferp.FERPDataset(pathname, subfolder, download=download)

        elif name == 'ck':
            idenselect = np.arange(20) + 0
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'ck',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'ckp':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'ckp',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'jaffe':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'jaffe',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'bu3dfe':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            #idenselect = np.array([0,1,2,3,4,5,6,7,8,9]) + 0
            data = fer.FERClassicDataset(pathname,
                                         'bu3dfe',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'afew':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = afew.Afew(pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)

        elif name == 'celeba':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = celeba.CelebaDataset(pathname,
                                        train=btrain,
                                        download=download)

        elif name == 'ferblack':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = ferfolder.FERFolderDataset(pathname,
                                              train=btrain,
                                              idenselect=idenselect,
                                              download=download)
            data.labels = np.array(data.targets)

        # metric learning dataset

        elif name == 'cub2011metric':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'cub2011')
            data = cub2011.CUB2011MetricLearning(pathname,
                                                 train=btrain,
                                                 download=download)
            data.labels = np.array(data.targets)

        elif name == 'cars196metric':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'cars196')
            data = cars196.Cars196MetricLearning(pathname,
                                                 train=btrain,
                                                 download=download)
            data.labels = np.array(data.targets)

        else:
            assert False, 'Unknown dataset name: {}'.format(name)

        data.btrain = (subset == 'train')
        return data
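
    # A minimal usage sketch for the factory above (hypothetical: 'FactoryProvider' stands in
    # for whatever class owns this method, and the path/transform are placeholders):
    #
    #   provider = FactoryProvider()
    #   train_data = provider.factory('~/.datasets', 'cifar10', subset='train',
    #                                 download=True, transform=transforms.ToTensor())
    #   print(len(train_data), train_data.labels[:10])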