Example #1
def create_dataset(ds_name, path, transform, classes=None):
    dataset = None

    if ds_name == 'imagenet':
        dataset = dset.ImageFolder(root=path, transform=transform)
    elif ds_name == 'lsun':
        # LSUN Cat
        if classes and any('cat' in c for c in classes):  # guard against classes=None
            dataset_idx = None
            for c in classes:
                if 'train' in c:
                    dataset_idx = 0
                elif 'val' in c:
                    dataset_idx = 1

            dataset = dset.ImageFolder(path, transform=transform)
            torch.manual_seed(42)
            train_val_datasets = torch.utils.data.random_split(
                dataset, [len(dataset) - 5000, 5000])
            dataset = train_val_datasets[dataset_idx]
        # LSUN Scene classes
        else:
            dataset = dset.LSUN(path, classes=classes, transform=transform)
    else:
        raise NotImplementedError()

    return dataset
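
A minimal usage sketch (the path, transform, and class name below are placeholders, not part of the original; imports are assumed as in the snippet):

from torchvision import transforms

example_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
# 'cat' in the class name selects the ImageFolder + random_split branch,
# and 'train' picks the first (len(dataset) - 5000) samples of the split.
train_set = create_dataset('lsun', '/data/lsun_cat', example_transform,
                           classes=['cat_train'])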
Example #2
def load_dataset():
    if opt.dataset == 'MNIST':
        out_dir = './dataset/MNIST'
        return datasets.MNIST(root=out_dir, train=True, download=True,
                              transform=transforms.Compose(
                                  [
                                      transforms.Resize(opt.imageSize),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,), (0.5,)),  # MNIST is single-channel
                                  ]
                              )), 1

    elif opt.dataset == 'lsun':
        out_dir = './dataset/lsun'
        return datasets.LSUN(root=out_dir, classes=['bedroom_train'],
                             transform=transforms.Compose([
                                 transforms.Resize(opt.imageSize),
                                 transforms.CenterCrop(opt.imageSize),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])), 3

    elif opt.dataset == 'cifar10':
        out_dir = './dataset/cifar10'
        return datasets.CIFAR10(root=out_dir, download=True, train=True,
                                transform=transforms.Compose([
                                    transforms.Resize(opt.imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])), 3

    raise ValueError('No valid dataset found in {}'.format(opt.dataset))
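
load_dataset reads a module-level opt object; a minimal stand-in for trying it out (the SimpleNamespace here is an assumption, and opt must live in the same module as the function):

from types import SimpleNamespace

opt = SimpleNamespace(dataset='cifar10', imageSize=64)
dataset, num_channels = load_dataset()  # the second element is the channel count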
Example #3
def load_data(image_data_type,
              path_to_folder,
              data_transform,
              batch_size,
              classes=None,
              num_workers=5):
    # torch issue
    # https://github.com/pytorch/pytorch/issues/22866
    torch.set_num_threads(1)
    if image_data_type == 'lsun':
        dataset = datasets.LSUN(path_to_folder,
                                classes=classes,
                                transform=data_transform)
    elif image_data_type == "image_folder":
        dataset = datasets.ImageFolder(root=path_to_folder,
                                       transform=data_transform)
    else:
        raise ValueError("Invalid image data type")
    dataset_loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=num_workers,
                                                 drop_last=True,
                                                 pin_memory=True)
    return dataset_loader
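
A sketch of calling load_data on an image folder; the path is a placeholder. Passing a (height, width) pair to Resize keeps every sample the same size so batches stack cleanly:

from torchvision import transforms

tfm = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
loader = load_data('image_folder', '/data/images', tfm, batch_size=32)
images, labels = next(iter(loader))  # images: (32, C, 64, 64)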
Example #4
def get_dataset(dataset_name, image_size, dataroot='/tmp'):
    assert dataset_name in ALLOWED_DATASET

    if dataset_name == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())
        in_channels = 3
    else:
        in_channels = 1 if dataset_name == 'mnist' else 3

        transform = [transforms.Resize(image_size)]
        if dataset_name != 'mnist':
            transform.append(transforms.CenterCrop(image_size))
        transform += [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.5, ) * in_channels,
                (0.5, ) * in_channels),
        ]
        transform = transforms.Compose(transform)

        if dataset_name in ['imagenet', 'folder', 'lfw']:
            dataset = dset.ImageFolder(root=dataroot, transform=transform)
        elif dataset_name == 'lsun':
            dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                                transform=transform)
        elif dataset_name == 'cifar10':
            dataset = dset.CIFAR10(root=dataroot, download=True,
                                   transform=transform)
        elif dataset_name == 'mnist':
            dataset = dset.MNIST(root=dataroot, download=True,
                                 transform=transform)

    return dataset, in_channels
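
Assuming ALLOWED_DATASET (defined elsewhere in the project) contains the name, a call might look like:

dataset, in_channels = get_dataset('cifar10', image_size=32, dataroot='/tmp')
print(len(dataset), in_channels)  # 50000 3 for the CIFAR-10 training split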
Example #5
def get_loader(image_path, image_size, dataset, batch_size, num_workers=2):
    """Build and return data loader."""
    
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        # transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    if dataset == 'LSUN':
        dataset = datasets.LSUN(image_path, classes=['church_outdoor_train'], transform=transform)
    elif dataset == 'CelebA_FD':
        dataset = datasets.ImageFolder(image_path, transform=transform)
    elif dataset == 'cifar':
        dataset = datasets.CIFAR10(image_path, transform=transform, download=True)
    elif dataset == 'CelebA':
        transform = transforms.Compose([
            transforms.CenterCrop(160),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = datasets.ImageFolder(image_path+'/CelebA', transform=transform)
    
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              drop_last=True)
    return data_loader
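
A hypothetical call; the image path is a placeholder and the LSUN LMDB files must already be on disk:

loader = get_loader('/data/lsun', image_size=64,
                    dataset='LSUN', batch_size=64, num_workers=4)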
Example #6
def download_lsun_data(data_name='',
                       data_path='./Data'):
    os.makedirs(data_path, exist_ok=True)

    from torchvision import datasets
    data_train = datasets.LSUN(data_path, classes='train')
    # data_train = datasets.LSUN(root=data_path, classes='test')
    print(data_train)

    def to_images(ary, ):
        # ary = ten.numpy()
        ary = ary / 255.0
        ary = np.transpose(ary, (0, 3, 1, 2))
        # ary = ary.reshape((-1, 3, 28, 28))
        # ary = np.pad(ary, ((0, 0), (0, 0), (2, 2), (2, 2)), mode='constant')
        # ary = ary.reshape((-1, 1, 28, 28, 3))
        ary = torch.tensor(ary, dtype=torch.float32)
        return ary

    def to_labels(ary, ):
        # ary = ten.numpy()
        ary = ary.reshape((-1, ))
        # classes_num = 10
        # data_sets = np.eye(classes_num)[data_sets] # one_hot
        ary = torch.tensor(ary, dtype=torch.long)
        return ary

    # LSUN stores its images in an LMDB database, so there is no `.data`
    # attribute; the raw JPEGs also vary in size, so resize before stacking.
    train_image = to_images(np.stack([
        np.asarray(data_train[i][0].resize((256, 256)))
        for i in range(len(data_train))
    ]))
    # train_label = to_labels(data_train.targets)
    # test_image = to_images(data_test.data)
    # test_label = to_labels(data_test.targets)
    return train_image, None,  # train_label, test_image, test_label
Example #7
def get_dataset(dataset_name):
    if dataset_name == 'celeba':
        celeba = dset.ImageFolder(root=r'D:\celeba2',  # raw string avoids backslash escapes
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(138),
                                      transforms.Resize(64),
                                      transforms.ToTensor(),
                                  ]))
        return celeba

    if dataset_name == 'lsun':
        lsun = dset.LSUN(
            root=g.default_data_dir + 'lsun/',
            classes=['bedroom_train'],
            transform=transforms.Compose([
                transforms.Resize(64),
                transforms.CenterCrop(64),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        return lsun

    if dataset_name == 'cifar10':
        # folder dataset
        cifar10 = dset.CIFAR10(
            root=g.default_data_dir,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(64),
                transforms.CenterCrop(64),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        return cifar10
Example #8
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args={},
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True, tanh_scale=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor, tanh_scale=tanh_scale,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    assert dataset
    # FakeData has no `classes` attribute, so guard the lookup.
    num_of_classes = len(getattr(dataset, 'classes', []))
    print("Data found!  # of images =", len(dataset), ", # of classes =", num_of_classes,
          ", classes:", getattr(dataset, 'classes', None))
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
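
make_transform is defined elsewhere in this project. Under that assumption, a plausible call (data_path must already exist, since the function asserts it before downloading):

loader, num_classes = make_dataloader(batch_size=64,
                                      dataset_type='cifar10',
                                      data_path='./data/cifar10',
                                      imsize=32, normalize=True)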
Example #9
def get_dataset(opt):
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),  # Scale was renamed to Resize
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(root=opt.dataroot,  # the db_path argument is now root
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    elif opt.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                                transform=transforms.ToTensor())
    assert dataset
    return dataset
Example #10
def get_dataset(name, data_dir, size=64, lsun_categories=None):
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])

    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
    elif name == 'npy':
        # Only support normalization for now
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
    elif name == 'lsun_class':
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
    else:
        raise NotImplementedError  # NotImplemented is not an exception

    return dataset
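
The final Lambda adds uniform dequantization noise, a common trick when training likelihood-based generative models on 8-bit images. A sketch of a call (npy_loader is assumed to be defined elsewhere; the CIFAR-10 branch needs nothing extra):

dataset = get_dataset('cifar10', './data', size=32)
image, label = dataset[0]  # pixel values jittered by up to 1/128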
Example #11
def dataloaders(name):
    transforms_list = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if name == "mnist":
        os.makedirs('./data/mnist', exist_ok=True)
        dataset = datasets.MNIST('./data/mnist', train=False, download=True,
                       transform=transforms_list)
    elif name == "custom_faces":
        dataset = CustomFaces(r"/home/saikat/PycharmProjects/DCGAN/data/custom_face", transform=transforms_list)
    elif name == "LSUN":
        #os.makedirs('./data/lsun', exist_ok=True)
        dataset = datasets.LSUN(r"/home/data/LSUN" ,["church_outdoor_train"], transforms_list)
    elif name == "imagenet":
        dataset = datasets.ImageFolder(r"/home/data/IMAGENET-1K/val", transforms_list)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)




    return dataloader
Example #12
def get_loader(image_path, image_size, dataset, batch_size, num_workers=2):
    """Build and return data loader."""
    if dataset == 'ring':
        dg = data_generator()
        dataloader = dg.sample(batch_size)
        return torch.FloatTensor(dataloader)

    if dataset == 'grid':
        dg = grid_data_generator()
        dataloader = dg.sample(batch_size)
        return torch.FloatTensor(dataloader)

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # transform = transforms.Compose([
    #     transforms.Resize((244, 244)),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.485, 0.456, 0.406),
    #                          std=(0.229, 0.224, 0.225))])
    if dataset == 'LSUN':
        dataset = datasets.LSUN(image_path,
                                classes=['church_outdoor_train'],
                                transform=transform)
        # 'church_outdoor' is one category of LSUN dataset
    elif dataset == 'CelebA_FD':
        dataset = datasets.ImageFolder(image_path, transform=transform)
    elif dataset == 'cifar':
        dataset = datasets.CIFAR10('../../data/cifar-10',
                                   transform=transform,
                                   download=True)
    elif dataset == 'CelebA':
        transform = transforms.Compose([
            transforms.CenterCrop(160),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        image_path = os.path.join(image_path, dataset)
        if not os.path.exists(image_path):
            os.makedirs(image_path)
        dataset = datasets.ImageFolder(image_path, transform=transform)
    elif dataset == 'mnist':
        dataset = datasets.MNIST('../../data/mnist',
                                 train=True,
                                 download=True,
                                 transform=transforms.Compose([
                                     transforms.Pad(2),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, ), (0.5, ))
                                 ]))
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              drop_last=True)

    return data_loader
Example #13
def dataloader(dataset, input_size, batch_size, data_root="data", split='train'):
    transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    # MNIST and Fashion-MNIST are single-channel, so they need a 1-channel Normalize.
    gray_transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.5,), std=(0.5,))])
    if dataset == 'mnist':
        data_path = os.path.join(data_root, "mnist")
        data_loader = DataLoader(
            datasets.MNIST(data_path, train=True, download=True, transform=gray_transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'fashion-mnist':
        data_path = os.path.join(data_root, "fashion-mnist")
        data_loader = DataLoader(
            datasets.FashionMNIST(data_path, train=True, download=True, transform=gray_transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'cifar10':
        data_path = os.path.join(data_root, "cifar10")
        data_loader = DataLoader(
            datasets.CIFAR10(data_path, train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'svhn':
        data_path = os.path.join(data_root, "svhn")
        data_loader = DataLoader(
            datasets.SVHN(data_path, split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'stl10':
        data_path = os.path.join(data_root, "stl10")
        data_loader = DataLoader(
            datasets.STL10(data_path, split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'lsun-bed':
        data_path = os.path.join(data_root, "lsun")
        data_loader = DataLoader(
            datasets.LSUN(data_path, classes=['bedroom_train'], transform=transform),
            batch_size=batch_size, shuffle=True)

    return data_loader
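
A sketch of pulling one batch (most branches download on first use; LSUN must be fetched manually):

loader = dataloader('cifar10', input_size=64, batch_size=128)
images, labels = next(iter(loader))  # images: (128, 3, 64, 64)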
Example #14
def compute_dataset_statistics(target_set="LSUN", batch_size=50, dims=2048, cuda=True):
    if target_set == "CIFAR10":
        imageSize = 64
        dataset = datasets.CIFAR10(root="~/datasets/data_cifar10", train=False, download=True,
                        transform=transforms.Compose([
                        transforms.Resize(imageSize),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    elif target_set == "MNIST":
        imageSize = 64
        dataset = datasets.MNIST(root="~/datasets", train=True, download=True, 
                        transform=transforms.Compose([
                        transforms.Resize(imageSize),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,)),
                        ]))
    elif target_set == "LSUN":
        imageSize = 128
        dataset = datasets.LSUN(root="/data1/zhangliangyu/datasets/data_lsun", classes=["church_outdoor_train"],
                        transform=transforms.Compose([
                            transforms.Resize(imageSize),
                            transforms.CenterCrop(imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx])
    if cuda:
        model.cuda()

    model.eval()
    pred_arr = np.empty((len(dataset), dims))
    start = 0

    print("Computing statistics of the given dataset...")

    for (x, y) in tqdm(data_loader):
        if target_set == "MNIST":
            tmp = torch.zeros((x.size()[0], 3, x.size()[2], x.size()[3]))
            for i in range(3):
                tmp[:, i, :, :] = x[:, 0, :, :]
            x = tmp
        end = start + x.size(0)
        if cuda:
            x = x.cuda()
        with torch.no_grad():  # inference only; avoid building a graph
            pred = model(x)[0]
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
        pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1)
        start = end

    mu = np.mean(pred_arr, axis=0)
    sigma = np.cov(pred_arr, rowvar=False)
    return mu, sigma
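
The returned mu and sigma are the Gaussian statistics the FID formula compares between real and generated images; InceptionV3 is assumed to come from the pytorch-fid package. A sketch:

mu, sigma = compute_dataset_statistics(target_set='CIFAR10',
                                       batch_size=50, cuda=False)
print(mu.shape, sigma.shape)  # (2048,) (2048, 2048)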
Example #15
def dataloader(dataset, input_size, batch_size, split='train'):
    transform = transforms.Compose([transforms.Resize((input_size, input_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    if dataset == 'mnist':
        transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
        data_loader = DataLoader(
            datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'fashion-mnist':
        # Fashion-MNIST is single-channel as well, so use a 1-channel Normalize.
        transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
        data_loader = DataLoader(
            datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'cifar10':
        data_loader = DataLoader(
            datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'svhn':
        data_loader = DataLoader(
            datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'stl10':
        data_loader = DataLoader(
            datasets.STL10('data/stl10', split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'lsun-bed':
        data_loader = DataLoader(
            datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
            batch_size=batch_size, shuffle=True)

    return data_loader
Example #16
def sample_data(path, batch_size, image_size, dataset):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (1, 1, 1)),
    ])

    if dataset == 'cifar10':
        dataset = datasets.CIFAR10(path, transform=transform)
    elif dataset == 'lsun':
        dataset = datasets.LSUN(path,
                                classes=['church_outdoor_train'],
                                transform=transform)
    else:
        dataset = datasets.ImageFolder(path, transform=transform)
    loader = DataLoader(dataset,
                        shuffle=True,
                        batch_size=batch_size,
                        num_workers=4)
    loader = iter(loader)

    while True:
        try:
            yield next(loader)

        except StopIteration:
            loader = DataLoader(dataset,
                                shuffle=True,
                                batch_size=batch_size,
                                num_workers=4)
            loader = iter(loader)
            yield next(loader)
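
Because sample_data rebuilds the loader whenever it hits StopIteration, it yields batches forever, which suits step-based GAN training loops. A sketch, assuming the CIFAR-10 files are already on disk (the snippet does not pass download=True):

data_iter = sample_data('./data/cifar10', batch_size=32,
                        image_size=64, dataset='cifar10')
for step in range(10):
    images, labels = next(data_iter)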
Example #17
def data_loader():
    kwopt = {'num_workers': 2, 'pin_memory': True} if opt.cuda else {}

    if opt.dataset == 'lsun':
        train_dataset = datasets.LSUN(root=opt.datapath + 'train/', classes=['bedroom_train'],
                                      transform=transforms.Compose([
                                          transforms.Resize(opt.image_size),
                                          transforms.CenterCrop(opt.image_size),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                      ]))
    elif opt.dataset == 'mnist':
        train_dataset = datasets.MNIST('./data', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.Resize(opt.image_size),
                                           transforms.CenterCrop(opt.image_size),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.5,), (0.5,))  # MNIST is single-channel
                                       ]))
        test_dataset = datasets.MNIST('./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, **kwopt)

    # test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, **kwopt)

    return train_loader
Example #18
def LSUN_loader(root, image_size, classes=['bedroom'], normalize=True):
    """
        Function to load torchvision dataset object based on just image size
        Args:
            root = Location of the LSUN LMDB files. Note that torchvision
                   does not download LSUN automatically; fetch it beforehand.
            image_size = Size of every image
            classes = Default class is 'bedroom'. Other available classes are:
                      'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room',
                      'kitchen', 'living_room', 'restaurant', 'tower'
            normalize = Requirement to normalize the image. Default is true
    """
    transformations = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ]
    if normalize:
        transformations.append(
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    # Rebinding the loop variable does not modify the list; build a new one.
    classes = [c + '_train' for c in classes]
    lsun_data = dset.LSUN(root=root,
                          classes=classes,
                          transform=transforms.Compose(transformations))
    return lsun_data
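
A hypothetical call; the root path is a placeholder and the LMDB files must be fetched beforehand:

lsun = LSUN_loader('./data/lsun', image_size=64,
                   classes=['bedroom', 'kitchen'])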
Example #19
def __getDataSet(opt):
    if isDebug: print(f"Getting dataset: {opt.dataset} ... ")

    dataset = None
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(root=opt.dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))

    return dataset
Example #20
def get_dataset(name,
                data_dir,
                size=64,
                lsun_categories=None,
                deterministic=False,
                transform=None):

    transform = transforms.Compose([
        t for t in [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            (not deterministic) and transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            (not deterministic) and transforms.Lambda(
                lambda x: x + 1. / 128 * torch.rand(x.size())),
        ] if t is not False
    ]) if transform is None else transform

    if name == 'image':
        print('Using image labels')
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'webp':
        print('Using no labels from webp')
        dataset = CachedImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'stacked_mnist':
        dataset = StackedMNIST(data_dir,
                               transform=transforms.Compose([
                                   transforms.Resize(size),
                                   transforms.CenterCrop(size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, ), (0.5, ))
                               ]))
        nlabels = 1000
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        raise NotImplementedError
    return dataset, nlabels
Example #21
def create_dataset(dataset_name, root=data_root, image_size=None):
    """Create dataset given dataset name. 

    args:
        dataset_name (str): dataset name

        image_size(int): output image size. Only for mnist, it's None

    returns:
        torch.util.data.dataset
    
    """
    path = os.path.join(root, dataset_name)
    if not os.path.exists(path):
        os.makedirs(path)

    if dataset_name == 'mnist':
        dataset = dset.MNIST(path,
                             train=True,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307, ), (0.3081, ))
                             ]))
    elif dataset_name == 'cifar10':
        dataset = dset.CIFAR10(root=path,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    elif dataset_name == 'cifar100':
        dataset = dset.CIFAR100(root=path,  # CIFAR-100, not CIFAR-10
                               download=True,
                               transform=transforms.Compose([
                                    transforms.Resize(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    elif dataset_name == 'lsun':
        dataset = dset.LSUN(root=path,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif dataset_name == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())
    else:
        raise ValueError("Unknown dataset name: {}".format(dataset_name))

    return dataset
Example #22
    def __init__(self, options):
        transform_list = []
        if options.image_size is not None:
            transform_list.append(
                transforms.Resize((options.image_size, options.image_size)))
            # transform_list.append(transforms.CenterCrop(options.image_size))
        transform_list.append(transforms.ToTensor())
        if options.image_colors == 1:
            transform_list.append(transforms.Normalize(mean=[0.5], std=[0.5]))
        elif options.image_colors == 3:
            transform_list.append(
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        transform = transforms.Compose(transform_list)

        if options.dataset == 'mnist':
            dataset = datasets.MNIST(options.data_dir,
                                     train=True,
                                     download=True,
                                     transform=transform)
        elif options.dataset == 'emnist':
            # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
            datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
            dataset = datasets.EMNIST(options.data_dir,
                                      split=options.image_class,
                                      train=True,
                                      download=True,
                                      transform=transform)
        elif options.dataset == 'fashion-mnist':
            dataset = datasets.FashionMNIST(options.data_dir,
                                            train=True,
                                            download=True,
                                            transform=transform)
        elif options.dataset == 'lsun':
            training_class = options.image_class + '_train'
            dataset = datasets.LSUN(options.data_dir,
                                    classes=[training_class],
                                    transform=transform)
        elif options.dataset == 'cifar10':
            dataset = datasets.CIFAR10(options.data_dir,
                                       train=True,
                                       download=True,
                                       transform=transform)
        elif options.dataset == 'cifar100':
            dataset = datasets.CIFAR100(options.data_dir,
                                        train=True,
                                        download=True,
                                        transform=transform)
        else:
            dataset = datasets.ImageFolder(root=options.data_dir,
                                           transform=transform)

        self.dataloader = DataLoader(dataset,
                                     batch_size=options.batch_size,
                                     num_workers=options.loader_workers,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=options.pin_memory)
        self.iterator = iter(self.dataloader)
Example #23
def gen_lsun_balanced(dataroot, nms, tr, indexes):
    sub_dss = []
    for nm in nms:
        sub_dss += [
            Subset(datasets.LSUN(dataroot, classes=[nm], transform=tr),
                   indexes)
        ]
    return ConcatUniClassDataset(sub_dss)
Example #24
def load_lsun(self, classes='bedroom_train'):  # church_outdoor_train
    transforms = self.transform(True, True, True, False)
    if self.image_path is None:
        raise NotImplementedError("LSUN is not downloaded.")
    dataset = dsets.LSUN(self.image_path,
                         classes=[classes],
                         transform=transforms)
    return dataset
Example #25
def dataloader(dataset, input_size, batch_size, split='train'):
    transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    # MNIST and Fashion-MNIST are single-channel and need a 1-channel Normalize.
    gray_transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])
    if dataset == 'mnist':
        data_loader = DataLoader(datasets.MNIST('data/mnist',
                                                train=True,
                                                download=True,
                                                transform=gray_transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    elif dataset == 'pickle':
        #features,attackers,defenders=load_data('data')
        #result=merge(features,defenders,attackers)
        dataset = data.generate_random()
        data_loader = DataLoader(dataset)
    elif dataset == 'fashion-mnist':
        data_loader = DataLoader(datasets.FashionMNIST('data/fashion-mnist',
                                                       train=True,
                                                       download=True,
                                                       transform=gray_transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    elif dataset == 'cifar10':
        data_loader = DataLoader(datasets.CIFAR10('data/cifar10',
                                                  train=True,
                                                  download=True,
                                                  transform=transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    elif dataset == 'svhn':
        data_loader = DataLoader(datasets.SVHN('data/svhn',
                                               split=split,
                                               download=True,
                                               transform=transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    elif dataset == 'stl10':
        data_loader = DataLoader(datasets.STL10('data/stl10',
                                                split=split,
                                                download=True,
                                                transform=transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    elif dataset == 'lsun-bed':
        data_loader = DataLoader(datasets.LSUN('data/lsun',
                                               classes=['bedroom_train'],
                                               transform=transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    elif dataset == 'pier':
        data_loader = DataLoader(datasets.ImageFolder('data/pier',
                                                      transform=transform),
                                 batch_size=batch_size,
                                 shuffle=True)
    return data_loader
Example #26
def get_dataset(name,
                data_dir,
                size=64,
                lsun_categories=None,
                load_in_mem=False):
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])
    data_dir = os.path.expanduser(data_dir)
    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'hdf5':
        from TOOLS.make_hdf5 import Dataset_HDF5
        transform = transforms.Compose([
            transforms.Lambda(lambda x: x.transpose(1, 2, 0)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
        ])
        dataset = Dataset_HDF5(root=data_dir,
                               transform=transform,
                               load_in_mem=load_in_mem)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        raise NotImplementedError

    return dataset, nlabels
Example #27
def load_lsun128(data_path, category):
    imageSize = 128
    train_data = datasets.LSUN(data_path,
                               classes=[category + '_train'],
                               transform=transforms.Compose([
                                   transforms.CenterCrop(256),
                                   transforms.Resize(imageSize),
                                   transforms.ToTensor(),
                               ]))

    val_data = datasets.LSUN(data_path,
                             classes=[category + '_val'],
                             transform=transforms.Compose([
                                 transforms.CenterCrop(256),
                                 transforms.Resize(imageSize),
                                 transforms.ToTensor(),
                             ]))
    return train_data, val_data
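
A sketch of a call; the path and category are placeholders, and the LSUN LMDBs must already be downloaded:

train_data, val_data = load_lsun128('./data/lsun', 'church_outdoor')
image, label = train_data[0]  # image: (3, 128, 128) in [0, 1]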
Example #28
def check_dataset(dataset, dataroot):
    """

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        dataset (data.Dataset): torchvision Dataset object

    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(
            root=dataroot,
            transform=transforms.Compose([resize, crop, to_tensor, normalize]),
        )
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot,
            classes=["bedroom_train"],
            transform=transforms.Compose([resize, crop, to_tensor, normalize]),
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(
            root=dataroot,
            download=True,
            transform=transforms.Compose([resize, to_tensor, normalize]),
        )
        nc = 3

    elif dataset == "mnist":
        dataset = dset.MNIST(
            root=dataroot,
            download=True,
            transform=transforms.Compose([resize, to_tensor, normalize]),
        )
        nc = 1

    elif dataset == "fake":
        dataset = dset.FakeData(size=256,
                                image_size=(3, 64, 64),
                                transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))

    return dataset, nc
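
The 'fake' branch needs no files on disk, which makes it handy for smoke-testing a pipeline:

dataset, nc = check_dataset('fake', dataroot='/tmp')
print(len(dataset), nc)  # 256 3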
Example #29
def load_lsun64(data_path, category):
    imageSize = 64
    train_data = datasets.LSUN(data_path,
                               classes=[category + '_train'],
                               transform=transforms.Compose([
                                   transforms.Resize(96),
                                   transforms.RandomCrop(imageSize),
                                   transforms.ToTensor(),
                               ]))

    val_data = datasets.LSUN(data_path,
                             classes=[category + '_val'],
                             transform=transforms.Compose([
                                 transforms.Resize(96),
                                 transforms.RandomCrop(imageSize),
                                 transforms.ToTensor(),
                             ]))
    return train_data, val_data
Example #30
def get_loader(_dataset, dataroot, batch_size, num_workers, image_size):
    # folder dataset
    if _dataset in ['imagenet', 'folder', 'lfw']:
        dataroot += '/resized_celebA'
        dataset = dset.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif _dataset == 'lsun':
        dataroot += '/lsun'
        dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))

    elif _dataset == 'cifar10':
        dataroot += '/cifar10'
        dataset = dset.CIFAR10(root=dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))

    elif _dataset == 'mnist':
        dataroot += '/mnist'
        dataset = dset.MNIST(root=dataroot, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),  # MNIST is single-channel
                             ]))

    elif _dataset == 'fashion':
        dataroot += '/fashion'
        dataset = dset.FashionMNIST(root=dataroot, download=True,
                                    transform=transforms.Compose([
                                        transforms.Resize(image_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,)),  # FashionMNIST is single-channel
                                    ]))

    elif _dataset == 'fake':
        dataroot += '/fake'
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=num_workers)

    return dataloader
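
A smoke test using the 'fake' branch (no downloads required; the sizes are placeholders):

loader = get_loader('fake', './data', batch_size=16,
                    num_workers=0, image_size=64)
images, labels = next(iter(loader))  # images: (16, 3, 64, 64)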