Example No. 1
    def __init__(self, args):
        super(Dataloader, self).__init__()
        self.args = args

        self.dataset_test_name = args.dataset_test
        self.dataset_train_name = args.dataset_train
        self.dataroot = args.dataroot
        self.batch_size = args.batch_size

        if self.dataset_train_name == "CELEBA":
            self.dataset_train, self.dataset_train_len = datasets.ImageFolder(
                root=self.dataroot + "/train")

        elif self.dataset_train_name == "MNIST":
            self.dataset_train, self.dataset_train_len = datasets.MNIST(
                self.dataroot).train()

        else:
            raise Exception("Unknown Dataset")

        if self.dataset_test_name == "CELEBA":
            self.dataset_test, self.dataset_test_len = datasets.ImageFolder(
                root=self.dataroot + "/test")

        elif self.dataset_test_name == "MNIST":
            self.dataset_test, self.dataset_test_len = datasets.MNIST(
                self.dataroot).test()

        else:
            raise Exception("Unknown Dataset")
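The snippet above relies on a project-specific `datasets` module whose constructors also return a length. For reference, a minimal sketch of the same split selection with the standard torchvision API (the root path and transform here are assumptions, not taken from the snippet):

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Hypothetical root; the snippet takes it from args.dataroot.
dataroot = "./data"

dataset_train = datasets.MNIST(root=dataroot, train=True, download=True,
                               transform=transforms.ToTensor())
dataset_train_len = len(dataset_train)

dataset_test = datasets.MNIST(root=dataroot, train=False, download=True,
                              transform=transforms.ToTensor())
dataset_test_len = len(dataset_test)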
Example No. 2
def train(traindir, sz, min_scale=0.08, shuffle_seed=0):
    train_tfms = [
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir,
                                         transforms.Compose(train_tfms))
    return PaddleDataLoader(train_dataset, shuffle_seed=shuffle_seed).reader()
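For comparison, a minimal sketch of the same augmentation pipeline fed into a plain PyTorch DataLoader instead of PaddleDataLoader (batch size and worker count are assumptions, not taken from the snippet):

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

def train_torch(traindir, sz, min_scale=0.08):
    tfms = transforms.Compose([
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(traindir, tfms)
    return torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True,
                                       num_workers=4, pin_memory=True)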
Example No. 3
def sort_ar(valdir):
    idx2ar_file = valdir + '/../sorted_idxar.p'
    if os.path.isfile(idx2ar_file):
        with open(idx2ar_file, 'rb') as f:
            return pickle.load(f)
    print(
        'Creating AR indexes. Please be patient, this may take a couple of minutes...'
    )
    val_dataset = datasets.ImageFolder(
        valdir)  # AS: TODO: use Image.open instead of looping through dataset
    sizes = [img[0].size for img in val_dataset]
    idx_ar = [(i, round(s[0] * 1.0 / s[1], 5)) for i, s in enumerate(sizes)]
    sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
    with open(idx2ar_file, 'wb') as f:
        pickle.dump(sorted_idxar, f)
    print('Done')
    return sorted_idxar
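A hedged sketch of the TODO noted above: reading image sizes with PIL directly instead of instantiating the whole ImageFolder; the ImageFolder-style directory layout (class subfolders containing image files) is an assumption:

import glob
import os
from PIL import Image

def sort_ar_fast(valdir):
    paths = sorted(glob.glob(os.path.join(valdir, '*', '*')))
    idx_ar = []
    for i, path in enumerate(paths):
        with Image.open(path) as img:  # lazy open: reads the header, not the pixels
            w, h = img.size
        idx_ar.append((i, round(w * 1.0 / h, 5)))
    return sorted(idx_ar, key=lambda x: x[1])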
Example No. 4
def test(valdir, bs, sz, rect_val=False):
    if rect_val:
        idx_ar_sorted = sort_ar(valdir)
        idx_sorted, _ = zip(*idx_ar_sorted)
        idx2ar = map_idx2ar(idx_ar_sorted, bs)

        ar_tfms = [transforms.Resize(int(sz * 1.14)), CropArTfm(idx2ar, sz)]
        val_dataset = ValDataset(valdir, transform=ar_tfms)
        return PaddleDataLoader(val_dataset,
                                concurrent=1,
                                indices=idx_sorted,
                                shuffle=False).reader()

    val_tfms = [transforms.Resize(int(sz * 1.14)), transforms.CenterCrop(sz)]
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))

    return PaddleDataLoader(val_dataset).reader()
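`map_idx2ar` is not defined in the snippet; judging by how its result feeds CropArTfm, one plausible implementation (an assumption, not the original helper) assigns every image in a validation batch the mean aspect ratio of that batch:

import numpy as np

def map_idx2ar(idx_ar_sorted, batch_size):
    idx2ar = {}
    for start in range(0, len(idx_ar_sorted), batch_size):
        chunk = idx_ar_sorted[start:start + batch_size]
        idxs, ars = zip(*chunk)
        mean_ar = round(float(np.mean(ars)), 5)
        for idx in idxs:
            idx2ar[idx] = mean_ar  # every image in the batch gets the batch-mean AR
    return idx2ar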
Example No. 5
    def __init__(self, args):
        self.args = args

        self.loader_input = args.loader_input
        self.loader_label = args.loader_label

        self.split_test = args.split_test
        self.split_train = args.split_train
        self.dataset_test_name = args.dataset_test
        self.dataset_train_name = args.dataset_train
        self.resolution = (args.resolution_wide, args.resolution_high)

        self.input_filename_test = args.input_filename_test
        self.label_filename_test = args.label_filename_test
        self.input_filename_train = args.input_filename_train
        self.label_filename_train = args.label_filename_train

        if self.dataset_train_name == 'LSUN':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                db_path=args.dataroot,
                classes=['bedroom_train'],
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'CIFAR10' or self.dataset_train_name == 'CIFAR100':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(self.resolution, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                ]))

        elif self.dataset_train_name == 'MYCIFAR10' or self.dataset_train_name == 'MYCIFAR100':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(self.resolution, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                ]))

        elif self.dataset_train_name == 'CocoCaption' or self.dataset_train_name == 'CocoDetection':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'STL10' or self.dataset_train_name == 'SVHN':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                split='train',
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'MNIST':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ]))

        elif self.dataset_train_name == 'ImageNet':
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.RandomSizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]))

        elif self.dataset_train_name == 'FRGC':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'Folder':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'FileListLoader':
            self.dataset_train = datasets.FileListLoader(
                self.input_filename_train,
                self.label_filename_train,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                transform_test=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        elif self.dataset_train_name == 'FolderListLoader':
            self.dataset_train = datasets.FileListLoader(
                self.input_filename_train,
                self.label_filename_train,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                transform_test=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        else:
            raise Exception("Unknown Dataset")

        if self.dataset_test_name == 'LSUN':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                db_path=args.dataroot,
                classes=['bedroom_val'],
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'CIFAR10' or self.dataset_test_name == 'CIFAR100':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                ]))

        elif self.dataset_test_name == 'MYCIFAR10' or self.dataset_test_name == 'MYCIFAR100':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                ]))

        elif self.dataset_test_name == 'CocoCaption' or self.dataset_test_name == 'CocoDetection':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'STL10' or self.dataset_test_name == 'SVHN':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                split='test',
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'MNIST':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ]))

        elif self.dataset_test_name == 'ImageNet':
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]))

        elif self.dataset_test_name == 'FRGC':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'Folder':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'FileListLoader':
            self.dataset_test = datasets.FileListLoader(
                self.input_filename_test,
                self.label_filename_test,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        elif self.dataset_test_name == 'FolderListLoader':
            self.dataset_test = datasets.FileListLoader(
                self.input_filename_test,
                self.label_filename_test,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )
        else:
            raise Exception("Unknown Dataset")
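A hedged sketch of how the `dataset_train`/`dataset_test` attributes built above are typically consumed; the wrapper function and batch settings are assumptions, not part of the snippet:

import torch

def create_loaders(dl, batch_size=64, workers=4):
    loader_train = torch.utils.data.DataLoader(dl.dataset_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    loader_test = torch.utils.data.DataLoader(dl.dataset_test,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=workers,
                                              pin_memory=True)
    return loader_train, loader_test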
Example No. 6
#        transforms.Resize(256),
#        transforms.CenterCrop(224),
#        transforms.ToTensor(),
#        normalize,
#    ]),Train = False),
#    batch_size=args.test_batch_size, shuffle=False,
#    num_workers=args.workers, pin_memory=True)
input_size = 224
normalize = transforms.Normalize(meanfile=args.data +
                                 '/imagenet_mean.binaryproto')

train_dataset = datasets.ImageFolder(
    args.data,
    transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
        transforms.RandomSizedCrop(input_size),
    ]),
    Train=True)

#if args.distributed:
#    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
#else:
train_sampler = None

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           sampler=train_sampler)
Example No. 7
def main():
    global args, best_prec1
    args = parser.parse_args()

    if platform.system() == "Windows":
        args.nocuda = True
    else:
        args.nocuda = False

    # create model
    if args.arch == 'alexnet':
        model = model_list.alexnet(pretrained=args.pretrained,
                                   base_number=args.base_number)
        input_size = 227
    else:
        raise Exception('Model not supported yet')

    model.features = torch.nn.DataParallel(model.features)
    if not args.nocuda:
        # set the seed
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        model.cuda()
        # define loss function (criterion) and optimizer
        criterion = nn.CrossEntropyLoss().cuda()
        # Set benchmark
        cudnn.benchmark = True
    else:
        criterion = nn.CrossEntropyLoss()

    global optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    # random initialization
    if not args.pretrained:
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                c = float(m.weight.data[0].nelement())
                m.weight.data = m.weight.data.normal_(0, 1.0 / c)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data = m.weight.data.zero_().add(1.0)
    else:
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data = m.weight.data.zero_().add(1.0)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # original saved file with DataParallel
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            print(checkpoint)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # # Data loading code

    # if you want to use the pre-processing used in Caffe:
    # transform = transforms.Compose([
    #     transforms.Resize((256, 256)),
    #     transforms.RandomResizedCrop(input_size),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: x * 255),
    #     transforms.Lambda(lambda x: torch.cat(reversed(torch.split(x, 1, 0)))),
    #     transforms.Lambda(lambda x: x - torch.Tensor([103.939, 116.779, 123.68]).view(3, 1, 1).expand(3, 227, 227))
    # ])
    # transform_val = transforms.Compose([
    #     transforms.Resize((256, 256)),
    #     transforms.CenterCrop(input_size),
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: x * 255),
    #     transforms.Lambda(lambda x: torch.cat(reversed(torch.split(x, 1, 0)))),
    #     transforms.Lambda(lambda x: x - torch.Tensor([103.939, 116.779, 123.68]).view(3, 1, 1).expand(3, 227, 227))
    # ])

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    transform_val = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    traindir = os.path.join(args.data, 'ILSVRC2012_img_train')
    valdir = os.path.join(args.data, 'ILSVRC2012_img_val')
    train_dataset = datasets.ImageFolder(traindir,
                                         transform,
                                         mapfile=os.path.join(
                                             args.data,
                                             "ImageNet12_train.txt"))
    val_dataset = datasets.ImageFolder(valdir,
                                       transform_val,
                                       mapfile=os.path.join(
                                           args.data, "ImageNet12_val.txt"))

    if not args.nocuda:
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers)

        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers)

    print(model)

    # define the binarization operator
    global bin_op
    bin_op = util.BinOp(model)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
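`adjust_learning_rate` is called but not shown in the snippet; a common step-decay implementation looks roughly like this (the decay interval and factor are assumptions):

def adjust_learning_rate(optimizer, epoch, decay_every=30, gamma=0.1):
    # decay the base learning rate by `gamma` every `decay_every` epochs
    lr = args.lr * (gamma ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr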
Example No. 8
def main():
    SIZE_IMG = 224
    # 224
    global args
    args = parser.parse_args()

    # create model
    if args.arch.startswith('resnet50'):
        model = model_defs.resnet50_oneway(num_classes=2)

    model = nn.DataParallel(model.model)
    # dirty trick

    # open log file
    if args.train == 1:
        log_dir = 'logs'
        log_name = args.arch + '_new.csv'
        if not os.path.isdir(log_dir):
            os.mkdir(log_dir)
        log_handle = get_file_handle(os.path.join(log_dir, log_name), 'wb+')
        log_handle.write('Epoch, LearningRate, Momentum, WeightDecay, ' + \
                         'Loss, Precision, Recall, Accuracy(IoU), FgWeight, BgWeight\n')
        log_handle.close()

    # check model directory
    model_dir = 'models'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # resume learning based on cmdline arguments
    if ((args.start_epoch > 1) and (args.train == 1)):
        load_epoch = args.start_epoch - 1
    elif (args.train == 0):
        load_epoch = args.load_epoch
    else:
        load_epoch = 0

    if load_epoch > 0:
        print("=> loading checkpoint for epoch = '{}'".format(load_epoch))
        checkpoint_name = args.arch + '_ep_' + str(load_epoch) + '.pth.tar'
        checkpoint = torch.load(os.path.join(model_dir, checkpoint_name))
        model.load_state_dict(checkpoint['state_dict'])

#    model = add_dropout2d(model);

    model.cuda()
    # transfer to cuda

    print(model)

    mean = load_pickle('./mean')
    std = load_pickle('./std')

    if args.train == 1:

        train_data_dir, train_gt_dir = args.data, args.gt
        train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            train_data_dir,
            train_gt_dir,
            transform_joint=transforms.Compose_Joint([
                transforms.RandomCrop(SIZE_IMG),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]),
            transform=transforms.Compose([
                transforms.ColorJitter(0.3, 0.3, 0.3, 0),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]),
            target_transform=transforms.Compose([
                transforms.ToTensorTarget(),
            ]),
            do_copy=True,
        ),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        weights = torch.from_numpy(np.array([1., 1.01])).float()
        criterion = nn.CrossEntropyLoss(weights).cuda()

        if args.optim == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.learning_rate,
                                         weight_decay=args.weight_decay)
        elif args.optim == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.learning_rate,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

        for epoch in range(args.start_epoch, args.end_epoch + 1):

            # train for one epoch
            stats_epoch = train(train_loader, model, criterion, optimizer,
                                epoch)

            model_name = args.arch + '_ep_' + str(epoch) + '.pth.tar'
            # get current parameters of optimizer
            for param_group in optimizer.param_groups:
                cur_lr = param_group['lr']
                cur_wd = param_group['weight_decay']
                if 'momentum' in param_group:
                    cur_momentum = param_group['momentum']
                else:
                    cur_momentum = 'n/a'
                break
                # constant parameters throughout the network

            if epoch % args.save_interval == 0:
                state = {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'learning_rate': cur_lr,
                    'momentum': cur_momentum,
                    'weight_decay': cur_wd,
                    'fg_weight': weights[1],
                    'bg_weight': weights[0],
                }

                torch.save(state, os.path.join(model_dir, model_name))

            # write logs using logHandle
            log_handle = get_file_handle(os.path.join(log_dir, log_name), 'ab')
            log_handle.write(
                str(epoch) + ',' + str(cur_lr) + ',' + str(cur_momentum) +
                ',' + str(cur_wd) + ',' + str(stats_epoch['loss']) + ',' +
                str(stats_epoch['prec']) + ',' + str(stats_epoch['recall']) +
                ',' + str(stats_epoch['acc']) + ',' + str(weights[1]) + ',' +
                str(weights[0]) + '\n')

            log_handle.close()

#            adjust_learning_rate(optimizer, epoch, 10); # adjust learning rate

    elif args.train == 0:  # test
        testdir = args.data
        outdir = args.out
        stride = args.test_stride
        test_batch_size = args.test_batch_size

        test_transformer = transforms.Compose([
            #            transforms.RandomHorizontalFlip(),
            #            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        #        test(testdir, outdir, test_transformer, model, load_epoch, stride, SIZE_IMG);
        test_batch_form(testdir, outdir, test_transformer, model, load_epoch,
                        stride, SIZE_IMG, test_batch_size)
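The hand-built CSV lines above could also be written with Python's csv module; a hedged sketch using the same columns as the header written earlier (the helper name is hypothetical):

import csv
import os

def append_log_row(log_path, row):
    # row: dict keyed by the column names below
    fieldnames = ['Epoch', 'LearningRate', 'Momentum', 'WeightDecay', 'Loss',
                  'Precision', 'Recall', 'Accuracy(IoU)', 'FgWeight', 'BgWeight']
    write_header = not os.path.isfile(log_path)
    with open(log_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row)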
Example No. 9
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    ''' random seed '''
    if args.seed is not None:
        random.seed(args.seed)
    else:
        args.seed = random.randint(1, 10000)

    torch.manual_seed(args.seed)
    cudnn.deterministic = True
    print('==> random seed:', args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    ''' data load info '''
    data_info = h5py.File(os.path.join('./data', args.data, 'data_info.h5'),
                          'r')
    img_path = str(data_info['img_path'][...]).replace("b'",
                                                       '').replace("'", '')
    args.c_att = torch.from_numpy(data_info['coarse_att'][...]).cuda()
    args.f_att = torch.from_numpy(data_info['fine_att'][...]).cuda()
    args.trans_map = torch.from_numpy(data_info['trans_map'][...]).cuda()
    args.num_classes, args.sf_size = args.c_att.size()
    ''' model building '''
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        best_prec1 = 0
        model, criterion = models.__dict__[args.arch](pretrained=True,
                                                      args=args)
    else:
        print("=> creating model '{}'".format(args.arch))
        model, criterion = models.__dict__[args.arch](args=args)
    print("=> is the backbone fixed: '{}'".format(args.is_fix))

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    criterion = criterion.cuda(args.gpu)
    ''' optimizer '''
    pse_params = [v for k, v in model.named_parameters() if 'ste' not in k]
    ste_params = [v for k, v in model.named_parameters() if 'ste' in k]

    pse_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            pse_params),
                                     args.lr,
                                     betas=(0.5, 0.999),
                                     weight_decay=args.weight_decay)
    ste_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            ste_params),
                                     args.lr,
                                     betas=(0.5, 0.999),
                                     weight_decay=args.weight_decay)
    ''' optionally resume from a checkpoint'''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            if (best_prec1 == 0):
                best_prec1 = checkpoint['best_prec1']
            print('=> pretrained acc {:.4F}'.format(best_prec1))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join('./data', args.data, 'train.list')
    valdir = os.path.join('./data', args.data, 'test.list')

    train_transforms, val_transforms = preprocess_strategy(args.data)

    train_dataset = datasets.ImageFolder(img_path, traindir, train_transforms)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        img_path, valdir, val_transforms),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # evaluate on validation set
    prec1 = validate(val_loader, model, criterion)
Example No. 10
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.arch == 'alexnet':
        model = alexnet.alexnet(pretrained=args.pretrained)
        input_size = 227
    else:
        raise Exception('Model not supported yet')

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if not os.path.exists(args.data + '/imagenet_mean.binaryproto'):
        print("==> Data directory " + args.data + " does not exist")
        print("==> Please specify the correct data path by")
        print("==>     --data <DATA_PATH>")
        return

    normalize = transforms.Normalize(meanfile=args.data +
                                     '/imagenet_mean.binaryproto')

    train_dataset = datasets.ImageFolder(
        args.data,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            transforms.RandomSizedCrop(input_size),
        ]),
        Train=True)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        args.data,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
            transforms.CenterCrop(input_size),
        ]),
        Train=False),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print(model)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example No. 11
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    ''' save path '''
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    ''' random seed '''
    if args.seed is not None:
        random.seed(args.seed)
    else:
        args.seed = random.randint(1, 10000)

    torch.manual_seed(args.seed)
    cudnn.deterministic = True
    print('==> random seed:', args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    ''' data load info '''
    data_info = h5py.File(os.path.join('./data', args.data, 'data_info.h5'),
                          'r')
    img_path = str(data_info['img_path'][...]).replace("b'",
                                                       '').replace("'", '')
    args.coarse_att = torch.from_numpy(data_info['coarse_att'][...]).cuda()
    args.fine_att = torch.from_numpy(data_info['fine_att'][...]).cuda()
    args.trans_map = torch.from_numpy(data_info['trans_map'][...]).cuda()
    args.num_classes, args.sf_size = args.coarse_att.size()

    if torch.cuda.device_count() > 1:
        args.coarse_att = torch.cat([args.coarse_att, args.coarse_att], dim=0)
        args.fine_att = torch.cat([args.fine_att, args.fine_att], dim=0)

    # adj
    coarse_adj = adj_matrix(data_info['coarse_att'][...], 5)
    args.coarse_adj = torch.from_numpy(coarse_adj).cuda()
    fine_att = adj_matrix(data_info['fine_att'][...], 5)
    args.fine_adj = torch.from_numpy(fine_att).cuda()
    ''' model building '''
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        best_prec1 = 0
        model, criterion = models.__dict__[args.arch](pretrained=True,
                                                      args=args)
    else:
        print("=> creating model '{}'".format(args.arch))
        model, criterion = models.__dict__[args.arch](args=args)
    print("=> is the backbone fixed: '{}'".format(args.is_fix))

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    criterion = criterion.cuda(args.gpu)
    ''' optimizer '''
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 args.lr,
                                 betas=(0.5, 0.999),
                                 weight_decay=args.weight_decay)
    ''' optionally resume from a checkpoint'''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            if (best_prec1 == 0):
                best_prec1 = checkpoint['best_prec1']
            print('=> pretrained acc {:.4F}'.format(best_prec1))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join('./data', args.data, 'train.list')
    valdir = os.path.join('./data', args.data, 'test.list')

    train_transforms, val_transforms = preprocess_strategy(args.data)

    train_dataset = datasets.ImageFolder(img_path, traindir, train_transforms)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        img_path, valdir, val_transforms),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              is_fix=args.is_fix)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # save model
        if args.is_fix:
            save_path = os.path.join(args.save_path, 'fix.model')
        else:
            save_path = os.path.join(
                args.save_path,
                args.arch + ('_{:.4f}.model').format(best_prec1))
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    #'optimizer' : optimizer.state_dict(),
                },
                filename=save_path)
            print('saving!!!!')
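`save_checkpoint` is not shown in the snippet; consistent with how it is called above, a minimal stand-in (an assumption) simply serializes the state dict to the given path:

import torch

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    # persist the whole state dict (epoch, arch, weights, best accuracy)
    torch.save(state, filename)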
Example No. 12
def fetch_dataset(data_name, subset):
    dataset = {}
    print('fetching data {}...'.format(data_name))
    root = './data/{}'.format(data_name)
    if data_name in ['MNIST', 'FashionMNIST', 'SVHN']:
        dataset['train'] = eval(
            'datasets.{}(root=root, split=\'train\', subset=subset,'
            'transform=datasets.Compose(['
            'transforms.ToTensor()]))'.format(data_name))
        dataset['test'] = eval(
            'datasets.{}(root=root, split=\'test\', subset=subset,'
            'transform=datasets.Compose([transforms.ToTensor()]))'.format(
                data_name))
        config.PARAM['transform'] = {
            'train':
            datasets.Compose(
                [transforms.Resize((32, 32)),
                 transforms.ToTensor()]),
            'test':
            datasets.Compose(
                [transforms.Resize((32, 32)),
                 transforms.ToTensor()])
        }
    elif data_name == 'EMNIST':
        dataset['train'] = datasets.EMNIST(root=root,
                                           split='train',
                                           subset=subset,
                                           transform=datasets.Compose(
                                               [transforms.ToTensor()]))
        dataset['test'] = datasets.EMNIST(root=root,
                                          split='test',
                                          subset=subset,
                                          transform=datasets.Compose(
                                              [transforms.ToTensor()]))
        config.PARAM['transform'] = {
            'train': datasets.Compose([transforms.ToTensor()]),
            'test': datasets.Compose([transforms.ToTensor()])
        }
    elif data_name in ['CIFAR10', 'CIFAR100']:
        dataset['train'] = eval(
            'datasets.{}(root=root, split=\'train\', subset=subset,'
            'transform=datasets.Compose(['
            'transforms.ToTensor()]))'.format(data_name))
        dataset['test'] = eval(
            'datasets.{}(root=root, split=\'test\', subset=subset,'
            'transform=datasets.Compose([transforms.ToTensor()]))'.format(
                data_name))
        config.PARAM['transform'] = {
            'train': datasets.Compose([transforms.ToTensor()]),
            'test': datasets.Compose([transforms.ToTensor()])
        }
    elif data_name == 'ImageNet':
        dataset['train'] = datasets.ImageNet(root,
                                             split='train',
                                             subset=subset,
                                             transform=datasets.Compose(
                                                 [transforms.ToTensor()]))
        dataset['test'] = datasets.ImageNet(root,
                                            split='test',
                                            subset=subset,
                                            transform=datasets.Compose(
                                                [transforms.ToTensor()]))
        config.PARAM['transform'] = {
            'train':
            datasets.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()]),
            'test':
            datasets.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])
        }
    elif data_name == 'Kodak':
        dataset['train'] = datasets.ImageFolder(root,
                                                transform=datasets.Compose(
                                                    [transforms.ToTensor()]))
        dataset['test'] = datasets.ImageFolder(root,
                                               transform=datasets.Compose(
                                                   [transforms.ToTensor()]))
        config.PARAM['transform'] = {
            'train': datasets.Compose([transforms.ToTensor()]),
            'test': datasets.Compose([transforms.ToTensor()])
        }
    else:
        raise ValueError('Not valid dataset name')
    dataset['train'].transform = config.PARAM['transform']['train']
    dataset['test'].transform = config.PARAM['transform']['test']
    print('data ready')
    return dataset
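The eval()-built constructor strings above can be replaced with getattr; a hedged equivalent sketch, assuming the project's dataset classes accept these keyword arguments (the helper name is hypothetical):

def make_split(data_name, root, split, subset):
    # look the class up by name instead of building a source string for eval()
    cls = getattr(datasets, data_name)
    return cls(root=root, split=split, subset=subset,
               transform=datasets.Compose([transforms.ToTensor()]))

# e.g. dataset['train'] = make_split(data_name, root, 'train', subset)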
Example No. 13
def main():
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.arch == 'alexnet':
        model = model_list.alexnet(pretrained=args.pretrained)
        input_size = 227
    else:
        raise Exception('Model not supported yet')

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)

    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            c = float(m.weight.data[0].nelement())
            m.weight.data = m.weight.data.normal_(0, 1.0 / c)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data = m.weight.data.zero_().add(1.0)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    if args.caffe_data:
        print('==> Using Caffe Dataset')
        cwd = os.getcwd()
        sys.path.append(cwd + '/../')
        import datasets as datasets
        import datasets.transforms as transforms
        if not os.path.exists(args.data + '/imagenet_mean.binaryproto'):
            print("==> Data directory " + args.data + " does not exist")
            print("==> Please specify the correct data path by")
            print("==>     --data <DATA_PATH>")
            return

        normalize = transforms.Normalize(meanfile=args.data +
                                         '/imagenet_mean.binaryproto')

        train_dataset = datasets.ImageFolder(
            args.data,
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
                transforms.RandomSizedCrop(input_size),
            ]),
            Train=True)

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            args.data,
            transforms.Compose([
                transforms.ToTensor(),
                normalize,
                transforms.CenterCrop(input_size),
            ]),
            Train=False),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        print('==> Using Pytorch Dataset')
        import torchvision
        import torchvision.transforms as transforms
        import torchvision.datasets as datasets
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        torchvision.set_image_backend('accimage')

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(input_size, scale=(0.40, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    print(model)

    # define the binarization operator
    global bin_op
    bin_op = util.BinOp(model)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example No. 14
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    ''' save path '''
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    ''' random seed '''
    if args.seed is not None:
        random.seed(args.seed)
    else:
        args.seed = random.randint(1, 10000)

    torch.manual_seed(args.seed)
    cudnn.deterministic = True
    print('==> random seed:', args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    ''' data load info '''
    data_info = h5py.File(os.path.join('./data', args.data, 'data_info.h5'),
                          'r')
    nc = data_info['all_att'][...].shape[0]
    sf_size = data_info['all_att'][...].shape[1]
    semantic_data = {
        'seen_class': data_info['seen_class'][...],
        'unseen_class': data_info['unseen_class'][...],
        'all_class': np.arange(nc),
        'all_att': data_info['all_att'][...]
    }
    ''' load semantic data'''
    args.num_classes = nc
    args.sf_size = sf_size
    args.sf = semantic_data['all_att']
    # build the adjacency matrix over the nc classes (adj_matrix comes from the surrounding project)
    adj = adj_matrix(nc)
    args.adj = adj
    ''' model building '''
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        best_prec1 = 0
        model, criterion = models.__dict__[args.arch](pretrained=True,
                                                      args=args)
    else:
        print("=> creating model '{}'".format(args.arch))
        model, criterion = models.__dict__[args.arch](args=args)
    print("=> is the backbone fixed: '{}'".format(args.is_fix))

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    criterion = criterion.cuda(args.gpu)
    ''' optimizer '''
    odr_params = [v for k, v in model.named_parameters() if 'odr_' in k]
    zsr_params = [v for k, v in model.named_parameters() if 'zsr_' in k]
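    # the 'odr_' / 'zsr_' name prefixes split the parameters so the two heads get separate optimizers below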

    odr_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           odr_params),
                                    args.lr1,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    zsr_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            zsr_params),
                                     args.lr2,
                                     betas=(0.5, 0.999),
                                     weight_decay=args.weight_decay)

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                args.lr1,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    ''' optionally resume from a checkpoint'''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            if (best_prec1 == 0):
                best_prec1 = checkpoint['best_prec1']
            print('=> pretrained acc {:.4F}'.format(best_prec1))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.data.lower() == 'cub':
        img_path = '/data/cq14/CUB/CUB_200_2011/images/'
    elif args.data.lower() == 'awa2':
        img_path = '~~~/Animals_with_Attributes2/JPEGImages/'
    elif args.data.lower() == 'sun':
        img_path = 'xxxx'
    elif args.data.lower() == 'apy':
        img_path = 'xxxx'

    traindir = os.path.join('./data', args.data, 'train.list')
    valdir1 = os.path.join('./data', args.data, 'test_seen.list')
    valdir2 = os.path.join('./data', args.data, 'test_unseen.list')

    train_transforms, val_transforms = preprocess_strategy(args.data, args)

    train_dataset = datasets.ImageFolder(img_path, traindir, train_transforms)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader1 = torch.utils.data.DataLoader(datasets.ImageFolder(
        img_path, valdir1, val_transforms),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    val_loader2 = torch.utils.data.DataLoader(datasets.ImageFolder(
        img_path, valdir2, val_transforms),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, odr_optimizer, zsr_optimizer, epoch,
                             args)

        # train for one epoch
        train(train_loader,
              semantic_data,
              model,
              criterion,
              optimizer,
              odr_optimizer,
              zsr_optimizer,
              epoch,
              is_fix=args.is_fix)

        # evaluate on validation set
        prec1 = validate(val_loader1, val_loader2, semantic_data, model,
                         criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # save model
        if args.is_fix:
            save_path = os.path.join(args.save_path, 'fix.model')
        else:
            save_path = os.path.join(
                args.save_path,
                args.arch + ('_{:.4f}.model').format(best_prec1))
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    #'optimizer' : optimizer.state_dict(),
                },
                filename=save_path)
            print('saving!!!!')
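
The loop above adjusts three optimizers through adjust_learning_rate, which is not shown. The following is only a hypothetical sketch of a step-decay version over those optimizers (the decay interval and factor are assumptions):

def adjust_learning_rate(optimizer, odr_optimizer, zsr_optimizer, epoch, args):
    # hypothetical step decay: scale the base learning rates down by 10x every 30 epochs
    scale = 0.1 ** (epoch // 30)
    for opt, base_lr in ((optimizer, args.lr1), (odr_optimizer, args.lr1),
                         (zsr_optimizer, args.lr2)):
        for group in opt.param_groups:
            group['lr'] = base_lr * scale
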
Ejemplo n.º 15
0
    def __init__(self, args):
        self.args = args

        self.loader_input = args.loader_input
        self.loader_label = args.loader_label
        self.prefetch = args.prefetch

        self.split_test = args.split_test
        self.split_train = args.split_train
        self.dataset_test_name = args.dataset_test
        self.dataset_train_name = args.dataset_train
        self.resolution = (args.resolution_wide, args.resolution_high)

        self.input_filename_test = args.input_filename_test
        self.label_filename_test = args.label_filename_test
        self.input_filename_train = args.input_filename_train
        self.label_filename_train = args.label_filename_train

        if self.dataset_train_name == 'LSUN':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=args.dataroot,
                classes=['bedroom_train'],
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'CASIA':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'RECONSTRUCTION':
            self.dataset_train = datasets.RECONSTRUCTION(
                root=self.args.dataroot,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]),
                nc=args.nchannels)

        elif self.dataset_train_name == 'CELEBA':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot +
                "/train",  # change it back to train before training
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.RandomHorizontalFlip(),
                    # transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.01),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'SSFF':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot + "/train",
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'CIFAR10' or self.dataset_train_name == 'CIFAR100':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(brightness=0.3,
                                           contrast=0.3,
                                           saturation=0.2,
                                           hue=0.01),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))

        elif self.dataset_train_name == 'GMM':
            self.dataset_train = datasets.GMM(args)

        elif self.dataset_train_name == 'GMMRing':
            self.dataset_train = datasets.GMM_Ring(args)

        elif self.dataset_train_name == 'CocoCaption' or self.dataset_train_name == 'CocoDetection':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'STL10' or self.dataset_train_name == 'SVHN':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                split='train',
                download=True,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'MNIST':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ]))

        elif self.dataset_train_name == 'ImageNet':
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            # self.dataset_train = datasets.ImageFolder(root=os.path.join(self.args.dataroot, "train"),
            #     transform=transforms.Compose([
            #         transforms.Scale(self.resolution),
            #         transforms.CenterCrop(self.resolution),
            #         transforms.RandomHorizontalFlip(),
            #         transforms.ToTensor(),
            #         # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            #         normalize,
            #        ])
            #     )
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataroot,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                    normalize
                ]),
                prefetch=self.args.prefetch)

        elif self.dataset_train_name == 'FRGC':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'Folder':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_train_name == 'FileList':
            self.dataset_train = datasets.FileList(
                self.input_filename_train,
                self.label_filename_train,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                transform_test=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        elif self.dataset_train_name == 'FolderList':
            self.dataset_train = datasets.FileList(
                self.input_filename_train,
                self.label_filename_train,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                transform_test=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        else:
            raise (Exception("Unknown Dataset"))

        if self.dataset_test_name == 'LSUN':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=args.dataroot,
                classes=['bedroom_val'],
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'CIFAR10' or self.dataset_test_name == 'CIFAR100':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'CELEBA':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + "/test",
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'RECONSTRUCTION':
            pass

        elif self.dataset_test_name == 'SSFF':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + "/test",
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'CocoCaption' or self.dataset_test_name == 'CocoDetection':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'STL10' or self.dataset_test_name == 'SVHN':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                split='test',
                download=True,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'MNIST':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ]))

        elif self.dataset_test_name == 'ImageNet':
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataroot,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                    normalize
                ]))

        elif self.dataset_test_name == 'FRGC':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'Folder':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataroot + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]))

        elif self.dataset_test_name == 'FileList':
            self.dataset_test = datasets.FileList(
                self.input_filename_test,
                self.label_filename_test,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        elif self.dataset_test_name == 'FolderList':
            self.dataset_test = datasets.FileList(
                self.input_filename_test,
                self.label_filename_test,
                self.split_train,
                self.split_test,
                train=True,
                transform_train=transforms.Compose([
                    transforms.Scale(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )
        elif self.dataset_test_name == 'GMM':
            self.dataset_test = datasets.GMM(args)

        elif self.dataset_test_name == 'GMMRing':
            self.dataset_test = datasets.GMM_Ring(args)

        elif self.dataset_test_name is None:
            pass

        else:
            raise (Exception("Unknown Dataset"))
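
This __init__ only builds the dataset objects; the loaders themselves would be created elsewhere in the class. A sketch of that step (assuming torch is imported at module level and that args.batch_size and args.nthreads are the argument names) could be:

    def create_loaders(self, args):
        # sketch only, not part of the original class
        self.loader_train = torch.utils.data.DataLoader(
            self.dataset_train, batch_size=args.batch_size,
            shuffle=True, num_workers=args.nthreads, pin_memory=True)
        self.loader_test = torch.utils.data.DataLoader(
            self.dataset_test, batch_size=args.batch_size,
            shuffle=False, num_workers=args.nthreads, pin_memory=True)
        return self.loader_train, self.loader_test
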
Ejemplo n.º 16
0
    def __init__(self, args):
        self.args = args

        self.loader_input = args.loader_input
        self.loader_label = args.loader_label

        self.dataset_options = args.dataset_options

        self.split = args.split
        self.dataset_test_name = args.dataset_test
        self.dataset_train_name = args.dataset_train

        self.train_dev_percent = args.train_dev_percent
        self.test_dev_percent = args.test_dev_percent

        self.resolution = (args.resolution_wide, args.resolution_high)
        self.input_size = (args.input_wide, args.input_high)

        self.input_filename_test = args.input_filename_test
        self.label_filename_test = args.label_filename_test
        self.input_filename_train = args.input_filename_train
        self.label_filename_train = args.label_filename_train

        if self.dataset_train_name == 'LSUN':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                db_path=args.dataset_root_train,
                classes=['bedroom_train'],
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_train_name == 'CIFAR10' or self.dataset_train_name == 'CIFAR100':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataset_root_train, train=True, download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(self.resolution, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])
            )

        elif self.dataset_train_name == 'CocoCaption' or self.dataset_train_name == 'CocoDetection':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataset_root_train, train=True, download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_train_name == 'STL10' or self.dataset_train_name == 'SVHN':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataset_root_train, split='train', download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_train_name == 'MNIST':
            self.dataset_train = getattr(datasets, self.dataset_train_name)(
                root=self.args.dataset_root_train, train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])
            )

        elif self.dataset_train_name == 'ImageNet':
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataset_root_train + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.RandomSizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])
            )

        elif self.dataset_train_name == 'FRGC':
            self.dataset_train = datasets.ImageFolder(
                root=self.args.dataset_root_train + self.args.input_filename_train,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_train_name == 'Folder':
            self.dataset_train = datasets.ImageFolder(
                root=os.path.join(self.args.dataset_root_train, self.args.input_filename_train),
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_train),
                )
            )

        elif self.dataset_train_name == 'FileListLoader':
            self.dataset_train = datasets.FileListLoader(
                self.input_filename_train, self.args.dataset_root_train,
                self.split,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_train)
                ),
                loader=self.loader_input,
            )

        elif self.dataset_train_name == 'FolderListLoader':
            self.dataset_train = datasets.FileListLoader(
                self.input_filename_train, self.label_filename_train,
                self.split_train, self.split_test, train=True,
                transform_train=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                transform_test=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )

        elif self.dataset_train_name == 'CSVListLoader':
            self.dataset_train = datasets.CSVListLoader(
                self.input_filename_train, self.args.dataset_root_train, 
                self.split,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_train)
                ),
                loader=self.loader_input,
            )

        elif self.dataset_train_name == 'TripletDataLoader':
            self.dataset_train = datasets.TripletDataLoader(
                self.args.dataset_root_train, self.input_filename_train, 
                self.args.batch_size_image,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_train)
                ),
            )

        elif self.dataset_train_name == 'TripletDataLoader_ImageRetrieval':
            if "\n" in self.input_filename_train:
                self.input_filename_train = self.input_filename_train.split('\n')
            else:
                raise(Exception("Format of input filename is wrong!"))

            assert(len(self.input_filename_train) == 2)
            self.dataset_train = datasets.TripletDataLoader_ImageRetrieval(
                self.args.dataset_root_train, 
                self.input_filename_train[0], self.input_filename_train[1], 
                self.args.batch_size_image,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_train)
                ),
            )

        elif self.dataset_train_name == 'Featpair':
            self.dataset_train = datasets.Featpair(
                self.args.input_filename_train,
                self.args.pair_index_filename,
                self.args.if_norm,
                self.args.num_images, self.args.in_dims,
                self.args.template_filename,
                )

        elif self.dataset_train_name == 'ClassPairDataLoader':
            self.dataset_train = datasets.ClassPairDataLoader(
                self.args.input_filename_train,
                self.args.if_norm,
                self.args.batch_size_image,
                )

        elif self.dataset_train_name == 'Featarray':
            self.dataset_train = datasets.Featarray(
                self.args.input_filename_train,
                self.args.if_norm,
                self.args.num_images, self.args.in_dims,
                )

        elif self.dataset_train_name is None:
            print("No training data assigned!")
        else:
            raise(Exception("Unknown Dataset"))

        if self.dataset_test_name == 'LSUN':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                db_path=args.dataset_root_test, classes=['bedroom_val'],
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.CenterCrop(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_test_name == 'CIFAR10' or self.dataset_test_name == 'CIFAR100':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataset_root_test, train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])
            )

        elif self.dataset_test_name == 'CocoCaption' or self.dataset_test_name == 'CocoDetection':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataset_root_test, train=False, download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_test_name == 'STL10' or self.dataset_test_name == 'SVHN':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataset_root_test, split='test', download=True,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_test_name == 'MNIST':
            self.dataset_test = getattr(datasets, self.dataset_test_name)(
                root=self.args.dataset_root_test, train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])
            )

        elif self.dataset_test_name == 'ImageNet':
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataset_root_test + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])
            )

        elif self.dataset_test_name == 'FRGC':
            self.dataset_test = datasets.ImageFolder(
                root=self.args.dataset_root_test + self.args.input_filename_test,
                transform=transforms.Compose([
                    transforms.Resize(self.resolution),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])
            )

        elif self.dataset_test_name == 'Folder':
            self.dataset_test = datasets.ImageFolder(
                root=os.path.join(self.args.dataset_root_test, self.args.input_filename_test),
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_test),
                )
            )

        elif self.dataset_test_name == 'FileListLoader':
            self.dataset_test = datasets.FileListLoader(
                self.input_filename_test, self.args.dataset_root_test,
                self.split,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_test)
                ),
                loader=self.loader_input,
            )

        elif self.dataset_test_name == 'FolderListLoader':
            self.dataset_test = datasets.FileListLoader(
                self.input_filename_test, self.label_filename_test,
                self.split_train, self.split_test, train=False,
                transform_test=transforms.Compose(
                    self.preprocess(self.args.preprocess_test)
                ),
                loader_input=self.loader_input,
                loader_label=self.loader_label,
            )
        
        elif self.dataset_test_name == 'CSVListLoader':
            self.dataset_test = datasets.CSVListLoader(
                self.input_filename_test, self.args.dataset_root_test,
                self.split, 
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_test)
                ),
                loader=self.loader_input,
            )

        elif self.dataset_test_name == 'TripletDataLoader':
            self.dataset_test = datasets.TripletDataLoader(
                self.args.dataset_root_test, self.input_filename_test, 
                self.args.batch_size_image,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_test)
                ),
            )

        elif self.dataset_test_name == 'TripletDataLoader_ImageRetrieval':
            if "\n" in self.input_filename_test:
                self.input_filename_test = self.input_filename_test.split('\n')
            else:
                raise(Exception("Format of input filename is wrong!"))

            assert(len(self.input_filename_test) == 2)
            self.dataset_test = datasets.TripletDataLoader_ImageRetrieval(
                self.args.dataset_root_test, 
                self.input_filename_test[0], self.input_filename_test[1],
                self.args.batch_size_image,
                transform=transforms.Compose(
                    self.preprocess(self.args.preprocess_test)
                ),
            )

        elif self.dataset_test_name == 'Featarray':
            self.dataset_test = datasets.Featarray(
                self.args.input_filename_test,
                self.args.if_norm,
                self.args.num_images, self.args.in_dims,
                )

        elif self.dataset_test_name is None:
            print("No testing data assigned!")
        else:
            raise(Exception("Unknown Dataset"))
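
Several branches above build their transform list through self.preprocess(...), which is not shown. The sketch below is a hypothetical version that maps a strategy keyword to a list of transforms (the keyword names are assumptions):

    def preprocess(self, strategy):
        # hypothetical mapping from a strategy keyword to a transform list
        if strategy == 'train_crop_flip':
            return [transforms.RandomResizedCrop(self.input_size),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor()]
        return [transforms.Resize(self.resolution),
                transforms.CenterCrop(self.resolution),
                transforms.ToTensor()]
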
Ejemplo n.º 17
0
def main():

    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.arch=='alexnet':
        model = model_list.alexnet(pretrained=args.pretrained)
        input_size = 224
    elif args.arch=='vgg16':
        model = model_list.vgg_net(pretrained=args.pretrained)
        input_size = 224
    elif args.arch=='vgg15_bwn':
        model = model_list.vgg_15(pretrained=args.pretrained)
        input_size = 224
    elif args.arch=='vgg15_bn_XNOR':
        model = model_list.vgg15_bn_XNOR(pretrained=args.pretrained)
        input_size = 224
    elif args.arch=='vgg15ab':
        model = model_list.vgg15ab(pretrained=args.pretrained)
        input_size = 224
    elif args.arch=='sq':
        model = model_list.squeezenet1_1()
        input_size = 224
    else:
        raise Exception('Model not supported yet')

    # if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
    #     pass
    #     model.features = torch.nn.DataParallel(model.features)
    #     model.cuda()
    # else:
    # model = torch.nn.DataParallel(model).cuda()
    model.cuda()
    # model.features = torch.nn.DataParallel(model.features)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                              # betas=(0.0, 0.999),
    #                              weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 betas=(0.0, 0.999),
                                weight_decay=args.weight_decay)
# scratch
#     for m in model.modules():
#         if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
#             c = float(m.weight.data[0].nelement())
#             m.weight.data = m.weight.data.normal_(0, 2.0/c)
#         elif isinstance(m, nn.BatchNorm2d):
#             m.weight.data = m.weight.data.zero_().add(1.0)
#             m.bias.data = m.bias.data.zero_()

    # optionally resume from a checkpoint
    # if args.resume:
    #     if os.path.isfile(args.resume):
    #         print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.pretrained)
    #         # TODO: Temporary remake
    #         # args.start_epoch = 0
    #         # best_prec1 = 0.0
    #         # model.features = torch.nn.DataParallel(model.features)
    try:
        args.start_epoch = checkpoint['epoch']
        if args.pretrained:
            best_prec1 = 0
        model = torch.nn.DataParallel(model)
        model.load_state_dict(checkpoint['state_dict'])
    except KeyError:
        model.load_state_dict(checkpoint)
        pass
    #
    #
    #
    #
    #         # optimizer.load_state_dict(checkpoint['optimizer'])
    #         print("=> loaded checkpoint '{}' (epoch {})"
    #               .format(args.resume, args.start_epoch))
    #         del checkpoint
    #     else:
    #         print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    if args.caffe_data:
        print('==> Using Caffe Dataset')
        cwd = os.getcwd()
        sys.path.append(cwd+'/../')
        import datasets as datasets
        import datasets.transforms as transforms
        if not os.path.exists(args.data+'/imagenet_mean.binaryproto'):
            print("==> Data directory"+args.data+"does not exits")
            print("==> Please specify the correct data path by")
            print("==>     --data <DATA_PATH>")
            return

        normalize = transforms.Normalize(
                meanfile=args.data+'/imagenet_mean.binaryproto')


        train_dataset = datasets.ImageFolder(
            args.data,
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
                transforms.RandomSizedCrop(input_size),
            ]),
            Train=True)

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(args.data, transforms.Compose([
                transforms.ToTensor(),
                normalize,
                transforms.CenterCrop(input_size),
            ]),
            Train=False),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
    elif args.cifar:
        import torchvision.transforms as transforms
        import torchvision
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                                  shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True, transform=transform)
        val_loader = torch.utils.data.DataLoader(testset, batch_size=100,
                                                 shuffle=False, num_workers=2)

        classes = ('plane', 'car', 'bird', 'cat',
                   'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


    else:
        print('==> Using Pytorch Dataset')
        import torchvision
        import torchvision.transforms as transforms
        import torchvision.datasets as datasets
        # traindir = os.path.join(args.data, 'train')
        # valdir = os.path.join(args.data, 'test')
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # distributed sampling is disabled here; the DistributedSampler line is kept for reference
        # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size//2 if args.arch.startswith('vgg') else args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
    # print (model)

    # define the binarization operator
    global bin_op

    bin_op = util.BinOp(model)
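    # BinOp tracks the binarizable layers so their weights can be binarized for a forward
    # pass and restored afterwards (see the evaluate branch below)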


    if args.evaluate:
        if args.binarize:
            bin_op.binarization()
            save_checkpoint(model.state_dict(), False, filename='{}/{}_bin_'.format(args.workdir, args.arch))
            bin_op.restore()
        # bin_op.binarization()
        # save_checkpoint(model.state_dict(), False, 'vgg_binarized')
        # bin_op.restore()
        validate(val_loader, model, criterion)
        return
    val_prec_list = []
    writer = SummaryWriter(args.workdir+'/runs/loss_graph')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, writer)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion)
        val_prec_list.append(prec1)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, filename='{}/{}_'.format(args.workdir, args.arch))
        writer.add_scalar('top1 accuracy', prec1, epoch)
        writer.add_scalar('top5 accuracy', prec5, epoch)
        writer.add_scalar('learning rate', args.lr, epoch)
    print(val_prec_list)
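
The train() call above is where bin_op is actually exercised. The rough sketch below shows how XNOR-Net-style training loops typically interleave those calls (updateBinaryGradWeight is an assumed helper name; only binarization/restore appear in this snippet):

# rough sketch of one binarized training step, not the original train()
for images, target in train_loader:
    bin_op.binarization()             # swap real-valued weights for their binarized versions
    output = model(images.cuda())
    loss = criterion(output, target.cuda())
    optimizer.zero_grad()
    loss.backward()
    bin_op.restore()                  # put the real-valued weights back
    bin_op.updateBinaryGradWeight()   # assumed helper: rescale gradients of binarized layers
    optimizer.step()
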
Ejemplo n.º 18
0
    # Data augmentation and normalization for training
    if args.arch == "inception":
        input_size = 299  # Inception models expect 299x299 inputs (not 229)
    else:
        input_size = 224
    """中心裁剪输出input_size,转换为Tensor,标准化"""
    test_transforms = transforms.Compose([
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    print("Initializing Datasets and Dataloaders...")
    # Build the test dataset
    datapath = Path(args.data)
    testdataset = datasets.ImageFolder(str(datapath / "test"), test_transforms)
    dataloader = DataLoader(dataset=testdataset,
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=torch.cuda.is_available())

    # Path where the test results will be saved
    resultpath = model_path.parent / "test_result.pkl"

    # Run the test (test() is assumed to also return the AUC meter printed below)
    test_acc, test_ap, test_auc = test(model, dataloader, args.num_workers,
                                       args.batch_size, resultpath)
    print("test_acc", test_acc.value())
    print("test_ap", test_ap.value())
    print("test_auc", test_auc.value()[0])
Ejemplo n.º 19
0
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    cudnn.deterministic = True
    cudnn.benchmark = True

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    ''' random seed '''
    if args.seed is not None:
        random.seed(args.seed)
    else:
        args.seed = random.randint(1, 10000)

    torch.manual_seed(args.seed)
    print('==> random seed:', args.seed)
    ''' data load '''
    if args.data == 'cub':
        img_root = '/userhome/raw_data/CUB_200_2011/CUB_200_2011/images'
        traindir = os.path.join(img_root, '../train.list')
        valdir = os.path.join(img_root, '../test.list')
        args.num_cls = 200
    elif args.data == 'car':
        img_root = '/userhome/raw_data/car_196/car_ims'
        traindir = os.path.join(img_root, '../train.list')
        valdir = os.path.join(img_root, '../test.list')
        args.num_cls = 196
    elif args.data == 'air':
        img_root = '/userhome/raw_data/aircraft/images'
        traindir = os.path.join(img_root, '../train.list')
        valdir = os.path.join(img_root, '../test.list')
        args.num_cls = 100
    elif args.data == 'dog':
        img_root = '/userhome/raw_data/dogs-120/Images'
        traindir = os.path.join(img_root, '../train.list')
        valdir = os.path.join(img_root, '../test.list')
        args.num_cls = 120

    train_transforms, val_transforms = preprocess_strategy(args)

    train_dataset = datasets.ImageFolder(img_root, traindir, train_transforms)
    val_dataset = datasets.ImageFolder(img_root, valdir, val_transforms)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    ''' model building '''
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        best_prec1 = 0
        model, criterion = models.__dict__[args.arch](pretrained=True,
                                                      args=args)
    else:
        print("=> creating model '{}'".format(args.arch))
        model, criterion = models.__dict__[args.arch](args=args)
    print("=> is the backbone fixed: '{}'".format(args.is_fix))

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    criterion = criterion.cuda(args.gpu)
    ''' optimizer '''
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                lr=args.lr,
                                momentum=args.momentum)

    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
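    # StepLR decays the learning rate by a factor of 10 every 30 epochs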
    ''' optionally resume from a checkpoint'''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            if (best_prec1 == 0):
                best_prec1 = checkpoint['best_prec1']
            print('=> pretrained acc {:.4F}'.format(best_prec1))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    ''' training '''
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        exp_lr_scheduler.step(epoch)

        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              is_fix=args.is_fix)

        # evaluate on validation set
        prec1 = validate(val_loader, model, val_dataset)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # save model
        if args.is_fix:
            save_path = os.path.join(args.save_path, 'fix.model')
        else:
            save_path = os.path.join(
                args.save_path,
                args.arch + ('_{:.4f}.model').format(best_prec1))

        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                },
                filename=save_path)
            print('saving!!!!')
def fetch_dataset(data_name):
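    # note: normalize, branch, and make_stats() used below are assumed to be defined at
    # module level in the original source; they are not shown in this snippet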
    print('fetching data {}...'.format(data_name))
    if (data_name == 'MNIST'):
        train_dir = './data/{}/train'.format(data_name)
        test_dir = './data/{}/test'.format(data_name)
        train_dataset = datasets.MNIST(root=train_dir,
                                       train=True,
                                       download=True,
                                       transform=transforms.ToTensor())
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(stats)])
            test_transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(stats)])
        else:
            train_transform = transforms.Compose([transforms.ToTensor()])
            test_transform = transforms.Compose([transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.MNIST(root=test_dir,
                                      train=False,
                                      download=True,
                                      transform=test_transform)

    elif (data_name == 'EMNIST' or data_name == 'EMNIST_byclass'
          or data_name == 'EMNIST_bymerge' or data_name == 'EMNIST_balanced'
          or data_name == 'EMNIST_letters' or data_name == 'EMNIST_digits'
          or data_name == 'EMNIST_mnist'):
        train_dir = './data/{}/train'.format(data_name.split('_')[0])
        test_dir = './data/{}/test'.format(data_name.split('_')[0])
        transform = transforms.Compose([transforms.ToTensor()])
        split = 'balanced' if len(
            data_name.split('_')) == 1 else data_name.split('_')[1]
        train_dataset = datasets.EMNIST(root=train_dir,
                                        split=split,
                                        train=True,
                                        download=True,
                                        transform=transform)
        test_dataset = datasets.EMNIST(root=test_dir,
                                       split=split,
                                       train=False,
                                       download=True,
                                       transform=transform)

    elif (data_name == 'FashionMNIST'):
        train_dir = './data/{}/train'.format(data_name)
        test_dir = './data/{}/test'.format(data_name)
        transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = datasets.FashionMNIST(root=train_dir,
                                              train=True,
                                              download=True,
                                              transform=transform)
        test_dataset = datasets.FashionMNIST(root=test_dir,
                                             train=False,
                                             download=True,
                                             transform=transform)

    elif (data_name == 'CIFAR10'):
        train_dir = './data/{}/train'.format(data_name)
        test_dir = './data/{}/validation'.format(data_name)
        train_dataset = datasets.CIFAR10(train_dir,
                                         train=True,
                                         transform=transforms.ToTensor(),
                                         download=True)
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
            test_transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(stats)])
        else:
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.CIFAR10(test_dir,
                                        train=False,
                                        transform=test_transform,
                                        download=True)

    elif (data_name == 'CIFAR100'):
        train_dir = './data/{}/train'.format(data_name)
        test_dir = './data/{}/validation'.format(data_name)
        train_dataset = datasets.CIFAR100(train_dir,
                                          train=True,
                                          transform=transforms.ToTensor(),
                                          download=True)
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
            test_transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(stats)])
        else:
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.CIFAR100(test_dir,
                                         train=False,
                                         transform=test_transform,
                                         download=True)

    elif (data_name == 'SVHN'):
        train_dir = './data/{}/train'.format(data_name)
        test_dir = './data/{}/validation'.format(data_name)
        train_dataset = datasets.SVHN(train_dir,
                                      split='train',
                                      transform=transforms.ToTensor(),
                                      download=True)
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(stats)])
            test_transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(stats)])
        else:
            train_transform = transforms.Compose([transforms.ToTensor()])
            test_transform = transforms.Compose([transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.SVHN(test_dir,
                                     split='test',
                                     transform=test_transform,
                                     download=True)

    elif (data_name == 'ImageNet'):
        train_dir = './data/{}/train'.format(data_name)
        test_dir = './data/{}/validation'.format(data_name)
        train_dataset = datasets.ImageFolder(train_dir,
                                             transform=transforms.ToTensor())
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
            test_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
        else:
            train_transform = transforms.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])
            test_transform = transforms.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.ImageFolder(test_dir, transform=test_transform)

    elif (data_name == 'CUB2011'):
        train_dir = './data/{}/train'.format(data_name.split('_')[0])
        test_dir = './data/{}/validation'.format(data_name.split('_')[0])
        train_dataset = datasets.CUB2011(train_dir,
                                         transform=transforms.Compose([
                                             transforms.Resize((224, 224)),
                                             transforms.ToTensor()
                                         ]),
                                         download=True)
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
            test_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
        else:
            train_transform = transforms.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])
            test_transform = transforms.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.CUB2011(test_dir,
                                        transform=test_transform,
                                        download=True)

    elif (data_name == 'WheatImage' or data_name == 'WheatImage_binary'
          or data_name == 'WheatImage_six'):
        train_dir = './data/{}/train'.format(data_name.split('_')[0])
        test_dir = './data/{}/validation'.format(data_name.split('_')[0])
        label_mode = 'six' if len(
            data_name.split('_')) == 1 else data_name.split('_')[1]
        train_dataset = datasets.WheatImage(train_dir,
                                            label_mode=label_mode,
                                            transform=transforms.Compose([
                                                transforms.Resize((224, 288)),
                                                transforms.ToTensor()
                                            ]))
        if (normalize):
            stats = make_stats(train_dataset, batch_size=128)
            train_transform = transforms.Compose([
                transforms.Resize((224, 288)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
            test_transform = transforms.Compose([
                transforms.Resize((224, 288)),
                transforms.ToTensor(),
                transforms.Normalize(stats)
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize((224, 288)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose(
                [transforms.Resize((224, 288)),
                 transforms.ToTensor()])
        train_dataset.transform = train_transform
        test_dataset = datasets.WheatImage(test_dir,
                                           label_mode=label_mode,
                                           transform=test_transform)

    elif (data_name == 'CocoDetection'):
        train_dir = './data/Coco/train2017'
        train_ann = './data/Coco/annotations/instances_train2017.json'
        test_dir = './data/Coco/val2017'
        test_ann = './data/Coco/annotations/instances_val2017.json'
        transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor()])
        train_dataset = datasets.CocoDetection(train_dir,
                                               train_ann,
                                               transform=transform)
        test_dataset = datasets.CocoDetection(test_dir,
                                              test_ann,
                                              transform=transform)

    elif (data_name == 'CocoCaptions'):
        train_dir = './data/Coco/train2017'
        train_ann = './data/Coco/annotations/captions_train2017.json'
        test_dir = './data/Coco/val2017'
        test_ann = './data/Coco/annotations/captions_val2017.json'
        transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor()])
        train_dataset = datasets.CocoCaptions(train_dir,
                                              train_ann,
                                              transform=transform)
        test_dataset = datasets.CocoCaptions(test_dir,
                                             test_ann,
                                             transform=transform)

    elif (data_name == 'VOCDetection'):
        train_dir = './data/VOC/VOCdevkit'
        test_dir = './data/VOC/VOCdevkit'
        transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor()])
        train_dataset = datasets.VOCDetection(train_dir,
                                              'trainval',
                                              transform=transform)
        test_dataset = datasets.VOCDetection(test_dir,
                                             'test',
                                             transform=transform)

    elif (data_name == 'VOCSegmentation'):
        train_dir = './data/VOC/VOCdevkit'
        test_dir = './data/VOC/VOCdevkit'
        transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor()])
        train_dataset = datasets.VOCSegmentation(train_dir,
                                                 'trainval',
                                                 transform=transform)
        test_dataset = datasets.VOCSegmentation(test_dir,
                                                'test',
                                                transform=transform)

    elif (data_name == 'MOSI' or data_name == 'MOSI_binary'
          or data_name == 'MOSI_five' or data_name == 'MOSI_seven'
          or data_name == 'MOSI_regression'):
        train_dir = './data/{}'.format(data_name.split('_')[0])
        test_dir = './data/{}'.format(data_name.split('_')[0])
        label_mode = 'five' if len(
            data_name.split('_')) == 1 else data_name.split('_')[1]
        train_dataset = datasets.MOSI(train_dir,
                                      split='trainval',
                                      label_mode=label_mode,
                                      download=True)
        stats = make_stats(train_dataset, batch_size=1)
        train_transform = transforms.Compose([transforms.Normalize(stats)])
        test_transform = transforms.Compose([transforms.Normalize(stats)])
        train_dataset.transform = train_transform
        test_dataset = datasets.MOSI(test_dir,
                                     split='test',
                                     label_mode=label_mode,
                                     download=True,
                                     transform=test_transform)

    elif (data_name == 'Kodak'):
        transform = transforms.Compose([transforms.ToTensor()])
        test_dir = './data/{}'.format(data_name)
        # the same folder is used for both splits
        train_dataset = datasets.ImageFolder(test_dir, transform)
        test_dataset = datasets.ImageFolder(test_dir, transform)

    elif (data_name == 'UCID'):
        transform = transforms.Compose([transforms.ToTensor()])
        test_dir = './data/{}'.format(data_name)
        # the same folder is used for both splits
        train_dataset = datasets.ImageFolder(test_dir, transform)
        test_dataset = datasets.ImageFolder(test_dir, transform)
    else:
        raise ValueError('Not a valid dataset name: {}'.format(data_name))
    print('data ready')
    return train_dataset, test_dataset
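

# make_stats() is referenced throughout fetch_dataset() but not defined in this
# snippet. The sketch below is one plausible implementation, assuming it returns
# per-channel (mean, std) computed over a fixed-size image dataset; the helper
# name, arguments, and return format are assumptions. Note that the original
# passes `stats` as a single argument to transforms.Normalize, which suggests a
# project-specific transforms wrapper; with plain torchvision the pair would be
# unpacked as transforms.Normalize(*stats). Assumes `torch` is imported at the
# top of the file.
def make_stats(dataset, batch_size=128):
    """Estimate per-channel mean and std of an image dataset."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
    n_pixels = 0
    channel_sum, channel_sq_sum = None, None
    for images, _ in loader:
        b, c, h, w = images.shape
        images = images.view(b, c, -1)
        if channel_sum is None:
            channel_sum = images.sum(dim=(0, 2))
            channel_sq_sum = (images ** 2).sum(dim=(0, 2))
        else:
            channel_sum += images.sum(dim=(0, 2))
            channel_sq_sum += (images ** 2).sum(dim=(0, 2))
        n_pixels += b * h * w
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0.0).sqrt()
    return mean, std


# Illustrative usage of fetch_dataset() with standard PyTorch dataloaders
# (the loader parameters below are examples, not taken from the original):
#
#   train_dataset, test_dataset = fetch_dataset('CIFAR10', normalize=True)
#   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
#                                              shuffle=True, num_workers=4)
#   test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128,
#                                             shuffle=False, num_workers=4)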