def create_model(arch, num_classes, layer):
    # check 'd1_resnet50' first: the plain 'resnet50' suffix also matches it
    if arch.endswith('d1_resnet50'):
        model = Resnet50_1d(0, num_classes, True, layer)
    elif arch.endswith('resnet50'):
        model = Resnet50(0, num_classes, True, layer)
    elif arch.endswith('resnet152_1x1'):
        model = Resnet152_1x1(0, num_classes, True, layer)
    elif arch.endswith('resnet152_1x1lap'):
        model = Resnet152_1x1LAP(0, num_classes, True, layer)
    elif arch.endswith('resnet152_1x1lmp'):
        model = Resnet152_1x1LMP(0, num_classes, True, layer)
    elif arch.endswith('resnet50_1x1lap'):
        model = Resnet50_1x1LAP(0, num_classes, True, layer)
    elif arch.endswith('resnet50_1x1lmp'):
        model = Resnet50_1x1LMP(0, num_classes, True, layer)
    elif arch.endswith('resnet152_truncated'):
        model = Resnet152_truncated(0, num_classes, True, layer)
    elif arch.endswith('resnet50_truncated'):
        model = Resnet50_truncated(0, num_classes, True, layer)
    elif arch.endswith('vgg16'):
        model = VGG16(0, num_classes, True)
    elif arch.endswith('vgg16_1d'):
        model = VGG16_1d(0, num_classes, True, layer)
    elif arch.endswith('vgg16_1x1lmp'):
        model = VGG16_1x1LMP(0, num_classes, True, layer)
    elif arch.endswith('vgg16_1x1lap'):
        model = VGG16_1x1LAP(0, num_classes, True, layer)
    elif arch.endswith('resnet50_1x1'):
        model = Resnet50_1x1(0, num_classes, True, layer)
    elif arch.endswith('d1_resnet152'):
        model = Resnet152_1d(0, num_classes, True, layer)
    elif arch.endswith('mobilenetv1_1x1lmp'):
        model = MobileNetV1_1x1LMP(1 - 0.999, num_classes, True, layer)
    elif arch.endswith('mobilenetv1_1x1lap'):
        model = MobileNetV1_1x1LAP(1 - 0.999, num_classes, True, layer)
    elif arch.endswith('mobilenetv2_1x1lmp'):
        model = MobileNetV2_1x1LMP(num_classes, layer)
    elif arch.endswith('mobilenetv2_1x1lap'):
        model = MobileNetV2_1x1LAP(num_classes, layer)
    else:
        raise ValueError("unsupported arch: '{}'".format(arch))
    return model
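
# A minimal usage sketch (hypothetical values: ImageNet's 1000 classes and a
# layer index of 33), assuming the model classes referenced above are importable:
if __name__ == '__main__':
    demo_model = create_model('resnet50_1x1lmp', 1000, 33)
    print(type(demo_model).__name__)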
Example #2



print(num_cpus())
print("sadsadsad")
path = untar_data(URLs.CIFAR, dest="./data/")
# tfms = [rand_resize_crop(224), flip_lr(p=0.5)]
ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [*center_crop(32)])
# ds_tfms = None
# n_gpus = 4
data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms,  bs=512, num_workers=6).normalize(cifar_stats)

# learn = Learner(data, resnet50(), metrics=accuracy)

learn = Learner(data, Resnet50(0, 10, True, 99), metrics=[accuracy, top_k_accuracy]).distributed(args.local_rank)
learn.to_fp32()
# learn.model = nn.parallel.DistributedDataParallel(learn.model)
# learn.model = nn.DataParallel(learn.model)
# print(learn.summary())



print('start training...')
learn.fit_one_cycle(35, 3e-3, wd=0.4)


# data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=(tfms, []), bs=512).normalize(cifar_stats)
# ds = data.train_ds
# learn = Learner(data, resnet50(), metrics=accuracy).to_fp16()
# learn.fit_one_cycle(30, 3e-3, wd=0.4, div_factor=10, pct_start=0.5)
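
# Hedged launch note: because the Learner above is wrapped with
# .distributed(args.local_rank), this script is meant to be started once per
# GPU via the torch launcher, e.g. (train_cifar.py is a hypothetical filename):
#   python -m torch.distributed.launch --nproc_per_node=4 train_cifar.py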
Example #3
def main_worker(gpu, ngpus_per_node, args):
    global best_acc

    mem = os.popen(
        '"nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
    ).read().split('\n')
    total = mem[0].split(',')[0]
    total = int(total)
    max_mem = int(total * 0.5)
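    # nvidia-smi reports memory in MiB; max_mem only feeds the commented-out
    # reservation below and is otherwise unused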
    # x = torch.rand((256, 1024, max_mem)).cuda()
    # del x

    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
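            # e.g. 2 nodes x 4 GPUs each: node rank 1, local gpu 2 -> global rank 1 * 4 + 2 = 6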
        torch.distributed.init_process_group(backend=args.dist_backend,
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.rank)

    num_classes = 1000

    # Model
    print("==> creating model '{}'".format(args.arch))
    # check 'd1_resnet50' first: the plain 'resnet50' suffix also matches it
    if args.arch.endswith('d1_resnet50'):
        model = Resnet50_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50'):
        model = Resnet50(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_shuffle'):
        model = Resnet50_Shuffle(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_1x1'):
        model = Resnet152_1x1(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_1x1lap'):
        model = Resnet152_1x1LAP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_1x1lmp'):
        model = Resnet152_1x1LMP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_1x1lap'):
        model = Resnet50_1x1LAP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_1x1lmp'):
        model = Resnet50_1x1LMP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_shuffle'):
        model = Resnet152_Shuffle(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_truncated'):
        model = Resnet152_truncated(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_truncated'):
        model = Resnet50_truncated(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_shuffle'):
        model = VGG16_Shuffle(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_rand'):
        model = VGG16_Rand(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_1d'):
        model = VGG16_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_1x1lmp'):
        model = VGG16_1x1LMP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_1x1lap'):
        model = VGG16_1x1LAP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_1x1'):
        model = Resnet50_1x1(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('d1_resnet152'):
        model = Resnet152_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('mobilenetv1_1x1lmp'):
        model = MobileNetV1_1x1LMP(1 - 0.999, num_classes, True, args.layer)
    elif args.arch.endswith('mobilenetv1_1x1lap'):
        model = MobileNetV1_1x1LAP(1 - 0.999, num_classes, True, args.layer)
    elif args.arch.endswith('mobilenetv2_1x1lmp'):
        model = MobileNetV2_1x1LMP(num_classes, args.layer)
    elif args.arch.endswith('mobilenetv2_1x1lap'):
        model = MobileNetV2_1x1LAP(num_classes, args.layer)
    else:
        raise ValueError("unsupported arch: '{}'".format(args.arch))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.train_batch = int(args.train_batch / ngpus_per_node)
            args.test_batch = int(args.test_batch / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
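            # e.g. train_batch=256 with ngpus_per_node=8 -> 32 samples per process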
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            warnings.warn(
                'NOT DISTRIBUTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        warnings.warn(
            'NOT DISTRIBUTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        warnings.warn(
            'NOT DISTRIBUTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)

    cudnn.benchmark = True

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    # for name, param in model.named_parameters():
    #     print(name)
    # for name in model.named_modules():
    #     print(name)
    # for param in model.parameters():
    #     print(param)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
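        # standard ImageNet channel means and stds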
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    if args.dataset == 'xian':
        print('ImageNet from Xian is used!')
        traindir = '/BS/xian/work/data/imageNet1K/train/'
        valdir = '/BS/database11/ILSVRC2012/val/'
    else:
        traindir = os.path.join(args.dataset, 'train')
        valdir = os.path.join(args.dataset, 'val')

    trainset = datasets.ImageFolder(traindir, transform_train)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
    else:
        train_sampler = None
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=(train_sampler is None),
                                  num_workers=args.workers,
                                  pin_memory=True,
                                  sampler=train_sampler)

    testset = datasets.ImageFolder(valdir, transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    # Resume
    title = 'imagenet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, args)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, args)
        test_loss, test_acc = test(testloader, model, criterion, args)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'acc': test_acc,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
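
# Hedged launch sketch (assuming the usual multiprocessing entry point for a
# main_worker with this signature):
#   torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node,
#                               args=(ngpus_per_node, args))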
def main():
    global best_acc
    num_classes = 1000

    # Model
    print("==> creating model '{}'".format(args.arch))
    # check 'd1_resnet50' first: the plain 'resnet50' suffix also matches it
    if args.arch.endswith('d1_resnet50'):
        model = Resnet50_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50'):
        model = Resnet50(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_1x1'):
        model = Resnet152_1x1(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_1x1lap'):
        model = Resnet152_1x1LAP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_1x1lmp'):
        model = Resnet152_1x1LMP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_1x1lap'):
        model = Resnet50_1x1LAP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_1x1lmp'):
        model = Resnet50_1x1LMP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet152_truncated'):
        model = Resnet152_truncated(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_truncated'):
        model = Resnet50_truncated(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16'):
        model = VGG16(args.drop, num_classes, True)
    elif args.arch.endswith('vgg16_1d'):
        model = VGG16_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_1x1lmp'):
        model = VGG16_1x1LMP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16_1x1lap'):
        model = VGG16_1x1LAP(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('resnet50_1x1'):
        model = Resnet50_1x1(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('d1_resnet152'):
        model = Resnet152_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('mobilenetv1_1x1lmp'):
        model = MobileNetV1_1x1LMP(1 - 0.999, num_classes, True, args.layer)
    elif args.arch.endswith('mobilenetv1_1x1lap'):
        model = MobileNetV1_1x1LAP(1 - 0.999, num_classes, True, args.layer)
    elif args.arch.endswith('mobilenetv2_1x1lmp'):
        model = MobileNetV2_1x1LMP(num_classes, args.layer)
    elif args.arch.endswith('mobilenetv2_1x1lap'):
        model = MobileNetV2_1x1LAP(num_classes, args.layer)
    else:
        raise ValueError("unsupported arch: '{}'".format(args.arch))

    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda(args.gpu_id)
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
    args.checkpoint = os.path.dirname(args.resume)
    checkpoint = torch.load(args.resume)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])

    print('==> Preparing dataset %s' % args.dataset)
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # TODO: remember to get rid of it
        # Rotation(-30),
        # Translation((-60./224., 0)),
        transforms.ToTensor(),
        # Center_block(0.5),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        # transforms.Normalize((124/255, 116/255, 104/255), (0.229, 0.224, 0.225)), # TODO: for meanvalue_background val
    ])

    # valdir = os.path.join(args.dataset, 'val')
    valdir = args.dataset
    testset = datasets.ImageFolder(valdir, transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    print(valdir)
    test_loss, test_acc = test(testloader, model, criterion, args)
    print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
    return

def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    num_classes = 1000

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnet50'):
        model = Resnet50(args.drop, num_classes, True, args.layer)
    elif args.arch.endswith('vgg16'):
        model = VGG16(args.drop, num_classes, True)
    elif args.arch.endswith('vgg16_1d'):
        model = VGG16_1d(args.drop, num_classes, True, args.layer)
    elif args.arch.startswith('d1_resnet50'):
        model = Resnet50_1d(args.drop, num_classes, True, args.layer)
    else:
        raise ValueError("unsupported arch: '{}'".format(args.arch))

    # DataParallel will divide and allocate batch_size to all available GPUs
    model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)

    cudnn.benchmark = True

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.Resize((args.img_size, args.img_size)),
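        # the extra Resize rescales the 224 random crop to args.img_size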
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    traindir = os.path.join(args.dataset, 'train')
    valdir = os.path.join(args.dataset, 'val')

    trainset = datasets.ImageFolder(traindir, transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True)

    testset = datasets.ImageFolder(valdir, transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    # Resume
    title = 'imagenet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, use_cuda)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'acc': test_acc,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)

def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnet50'):
        model = Resnet50(0, 1000, True, 99)
    elif args.arch.endswith('vgg16'):
        model = VGG16(0, 1000, True)
    else:
        raise ValueError(
            "arch can only be vgg16 or resnet50, got '{}'".format(args.arch))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
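            # only the convolutional features are data-parallelized; the large
            # fully-connected classifier stays on a single device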
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val_orig')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)