Esempio n. 1
0
def get_dataset(name, data_dir, size=64, lsun_categories=None):
    """Build the dataset selected by ``name``.

    Args:
        name: Dataset kind. One of ``'image'``, ``'npy'``, ``'cifar10'``,
            ``'lsun'`` or ``'lsun_class'``.
        data_dir: Root directory handed to the dataset constructor.
        size: Value passed to both ``Resize`` and ``CenterCrop``.
        lsun_categories: ``classes`` argument for ``datasets.LSUN``;
            ``'train'`` is used when this is ``None``.

    Returns:
        The constructed dataset.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    supported = ('image', 'npy', 'cifar10', 'lsun', 'lsun_class')
    if name not in supported:
        # Fail fast, before any transform or dataset object is created.
        # ``NotImplemented`` is a comparison sentinel, not an exception, so
        # raising it would surface as an unrelated TypeError.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Add torch.rand noise scaled by 1/128 to every normalized tensor.
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])

    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
    elif name == 'npy':
        # Only support normalization for now: no transform is passed here.
        # A tuple of extensions behaves the same whether the library
        # iterates over it or hands it straight to ``str.endswith``.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ('npy',))
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
    elif name == 'lsun_class':
        # Every sample gets the constant target 0.
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
    else:
        # Defensive: only reachable if ``supported`` and the branches above
        # drift apart.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    return dataset
def get_dataset(name,
                data_dir,
                size=64,
                lsun_categories=None,
                deterministic=False,
                transform=None):
    """Build the dataset selected by ``name`` and report its label count.

    Args:
        name: Dataset kind. One of ``'image'``, ``'webp'``, ``'npy'``,
            ``'cifar10'``, ``'stacked_mnist'``, ``'lsun'`` or
            ``'lsun_class'``.
        data_dir: Root directory handed to the dataset constructor.
        size: Value passed to ``Resize`` and ``CenterCrop``.
        lsun_categories: ``classes`` argument for ``datasets.LSUN``;
            ``'train'`` is used when this is ``None``.
        deterministic: When true, the default transform omits the random
            horizontal flip and the additive random noise.
        transform: Optional ready-made transform. When given it replaces the
            default pipeline (and ``size`` / ``deterministic`` are then not
            used for it). The ``'npy'`` and ``'stacked_mnist'`` branches do
            not apply this transform.

    Returns:
        A ``(dataset, nlabels)`` tuple.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    supported = ('image', 'webp', 'npy', 'cifar10', 'stacked_mnist',
                 'lsun', 'lsun_class')
    if name not in supported:
        # Fail fast, before any transform or dataset object is created.
        # ``NotImplemented`` is a comparison sentinel, not an exception, so
        # raising it would surface as an unrelated TypeError.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    if transform is None:
        # Assemble the default pipeline step by step; the two random steps
        # are skipped when a deterministic pipeline is requested.
        steps = [transforms.Resize(size), transforms.CenterCrop(size)]
        if not deterministic:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        if not deterministic:
            # Add torch.rand noise scaled by 1/128 to every tensor.
            steps.append(transforms.Lambda(
                lambda x: x + 1. / 128 * torch.rand(x.size())))
        transform = transforms.Compose(steps)

    if name == 'image':
        print('Using image labels')
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'webp':
        print('Using no labels from webp')
        dataset = CachedImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now: no transform is passed here.
        # A tuple of extensions behaves the same whether the library
        # iterates over it or hands it straight to ``str.endswith``.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ('npy',))
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'stacked_mnist':
        # Uses its own single-channel normalization instead of ``transform``.
        dataset = StackedMNIST(data_dir,
                               transform=transforms.Compose([
                                   transforms.Resize(size),
                                   transforms.CenterCrop(size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, ), (0.5, ))
                               ]))
        nlabels = 1000
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        # Every sample gets the constant target 0, hence a single label.
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # Defensive: only reachable if ``supported`` and the branches above
        # drift apart.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))
    return dataset, nlabels
Esempio n. 3
0
        def loader(transform, batch_size):
            """Return a shuffling DataLoader over the LSUN class at ``path``.

            ``path`` and ``args`` come from the enclosing scope. Every sample
            is given the constant target 0, and memory pinning is switched on
            only when ``args.gpu_count`` is greater than one.
            """
            lsun_set = datasets.LSUNClass(path,
                                          transform=transform,
                                          target_transform=lambda x: 0)
            return DataLoader(
                lsun_set,
                shuffle=True,
                batch_size=batch_size,
                num_workers=4,
                pin_memory=(args.gpu_count > 1),
            )
Esempio n. 4
0
    def loader(transform):
        """Return a non-shuffling DataLoader over the LSUN class at ``path``.

        ``path`` and ``batch_size`` come from the enclosing scope. Every
        sample is given the constant target 0.
        """
        lsun_set = datasets.LSUNClass(path,
                                      transform=transform,
                                      target_transform=lambda x: 0)
        return DataLoader(
            lsun_set,
            shuffle=False,
            batch_size=batch_size,
            num_workers=4,
        )
Esempio n. 5
0
def get_dataset(name,
                data_dir,
                size=64,
                lsun_categories=None,
                load_in_mem=False):
    """Build the dataset selected by ``name`` and report its label count.

    Args:
        name: Dataset kind. One of ``'image'``, ``'hdf5'``, ``'npy'``,
            ``'cifar10'``, ``'lsun'`` or ``'lsun_class'``.
        data_dir: Root directory; ``~`` is expanded before use.
        size: Value passed to ``Resize`` and ``CenterCrop`` in the default
            transform. The ``'hdf5'`` and ``'cifar10'`` branches build their
            own transforms without resizing.
        lsun_categories: ``classes`` argument for ``datasets.LSUN``;
            ``'train'`` is used when this is ``None``.
        load_in_mem: Forwarded to ``Dataset_HDF5`` for the ``'hdf5'`` kind.

    Returns:
        A ``(dataset, nlabels)`` tuple.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    supported = ('image', 'hdf5', 'npy', 'cifar10', 'lsun', 'lsun_class')
    if name not in supported:
        # Fail fast, before any transform or dataset object is created.
        # ``NotImplemented`` is a comparison sentinel, not an exception, so
        # raising it would surface as an unrelated TypeError.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Add torch.rand noise scaled by 1/128 to every normalized tensor.
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])
    data_dir = os.path.expanduser(data_dir)
    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'hdf5':
        # Imported lazily so the HDF5 tooling is only needed for this kind.
        from TOOLS.make_hdf5 import Dataset_HDF5
        transform = transforms.Compose([
            # NOTE(review): presumably converts stored channel-first arrays
            # to channel-last before ToTensor - confirm against Dataset_HDF5.
            transforms.Lambda(lambda x: x.transpose(1, 2, 0)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
        ])
        dataset = Dataset_HDF5(root=data_dir,
                               transform=transform,
                               load_in_mem=load_in_mem)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now: no transform is passed here.
        # A tuple of extensions behaves the same whether the library
        # iterates over it or hands it straight to ``str.endswith``.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ('npy',))
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        # CIFAR-10 uses a plain tensor conversion plus normalization only.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        # Every sample gets the constant target 0, hence a single label.
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # Defensive: only reachable if ``supported`` and the branches above
        # drift apart.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    return dataset, nlabels
Esempio n. 6
0
def get_dataset(name, data_dir, size=64, lsun_categories=None, config=None):
    """Build the dataset selected by ``name`` and report its label count.

    Args:
        name: Dataset kind. One of ``'MoG'``, ``'celeba'`` (matched
            case-insensitively), ``'image'``, ``'npy'``, ``'cifar10'``,
            ``'lsun'`` or ``'lsun_class'``.
        data_dir: Root directory handed to the folder-based datasets. It is
            not used for ``'MoG'`` or ``'celeba'``.
        size: Value passed to ``Resize`` and ``CenterCrop``.
        lsun_categories: ``classes`` argument for ``datasets.LSUN``;
            ``'train'`` is used when this is ``None``.
        config: Passed to ``MixtureOfGaussianDataset`` for the ``'MoG'`` kind.

    Returns:
        A ``(dataset, nlabels)`` tuple.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    exact_names = ('MoG', 'image', 'npy', 'cifar10', 'lsun', 'lsun_class')
    if name not in exact_names and name.lower() != 'celeba':
        # Fail fast, before any transform or dataset object is created.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Add torch.rand noise scaled by 1/128 to every normalized tensor.
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])

    if name == "MoG":
        dataset = MixtureOfGaussianDataset(config)
        nlabels = 1
    elif name.lower() == "celeba":
        # The image array is read from a fixed absolute path, not data_dir.
        imgs = np.load("/home/LargeData/celebA_64x64.npy")
        # All samples share the single label 0.
        labels = np.zeros([imgs.shape[0]]).astype(np.int64)
        dataset = NumpyImageDataset(imgs, labels, transform)
        nlabels = 1
    elif name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now: no transform is passed here.
        # Extensions are given as a tuple: a bare string would be matched
        # character by character if the library iterates over it.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ('npy',))
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        # Every sample gets the constant target 0, hence a single label.
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # Defensive: only reachable if the up-front check and the branches
        # above drift apart.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    return dataset, nlabels
Esempio n. 7
0
    # Register the generator's ``style`` sub-module parameters as their own
    # optimizer group at 1% of the base learning rate.
    # NOTE(review): 'mult' is not a standard optimizer option; presumably the
    # training loop reads it when rescaling learning rates - confirm.
    g_optimizer.add_param_group(
        {
            'params': generator.module.style.parameters(),
            'lr': args.lr * 0.01,
            'mult': 0.01,
        }
    )
    # Discriminator optimizer: Adam at the base learning rate.
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))

    # NOTE(review): presumably initializes g_running from the generator's
    # weights (third argument 0) - confirm against accumulate()'s definition.
    accumulate(g_running, generator.module, 0)

    # Pick the dataset source; no transform is attached at this point.
    # NOTE(review): if args.data is neither 'folder' nor 'lsun', ``dataset``
    # is never bound and the train() call below fails with a NameError,
    # unless the argument parser (not visible here) restricts the choices.
    if args.data == 'folder':
        dataset = datasets.ImageFolder(args.path)

    elif args.data == 'lsun':
        # Every sample gets the constant target 0.
        dataset = datasets.LSUNClass(args.path, target_transform=lambda x: 0)

    # From here on args.lr and args.batch are dicts, not scalars; the scalar
    # args.lr was already consumed by the optimizer setup above.
    # NOTE(review): the integer keys look like image resolutions - confirm.
    if args.sched:
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}

    else:
        # Empty tables when no schedule is requested.
        args.lr = {}
        args.batch = {}

    args.gen_sample = {512: (8, 4), 1024: (4, 2)}

    args.batch_default = 32

    train(args, dataset, generator, discriminator)
Esempio n. 8
0
                             transform=transforms.Compose(
                                 _CIFAR_TRAIN_TRANSFORMS)),
    'cifar100':
    lambda: datasets.CIFAR100('./datasets/cifar100',
                              train=True,
                              download=True,
                              transform=transforms.Compose(
                                  _CIFAR_TRAIN_TRANSFORMS)),
    'shvn':
    lambda: datasets.SVHN('./datasets/shvn',
                          download=True,
                          split='train',
                          transform=transforms.Compose(_SHVN_TEST_TRANSFORMS)),
    'lsun':
    lambda: datasets.LSUNClass('./datasets/lsun/bedroom_train',
                               transform=transforms.Compose(
                                   _LSUN_TRAIN_TRANSFORMS))
}

TEST_DATASETS = {
    'mnist':
    lambda: datasets.MNIST('./datasets/mnist',
                           train=False,
                           transform=transforms.Compose(_MNIST_TEST_TRANSFORMS)
                           ),
    'cifar10':
    lambda: datasets.CIFAR10('./datasets/cifar10',
                             train=False,
                             transform=transforms.Compose(
                                 _CIFAR_TEST_TRANSFORMS)),
    'cifar100':
Esempio n. 9
0
def get_dataset(name, data_dir, train=True, size=64, lsun_categories=None):
    """Build the dataset selected by ``name`` and report its label count.

    Args:
        name: Dataset kind. One of ``'image'``, ``'npy'``, ``'imagenet32'``,
            ``'imagenet64'``, ``'mnist'``, ``'cifar10'``, ``'fashionmnist'``,
            ``'lsun'``, ``'lsun_class'`` or ``'celeba'``.
        data_dir: Root directory handed to the dataset constructor.
        train: Selects the train or test split. Only the ``'mnist'``,
            ``'cifar10'`` and ``'fashionmnist'`` branches use it.
        size: Target size used by the resize / crop transforms and passed
            as ``img_size`` to ``ImageNetDataset``.
        lsun_categories: ``classes`` argument for ``datasets.LSUN``;
            ``'train'`` is used when this is ``None``.

    Returns:
        A ``(dataset, nlabels)`` tuple.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    supported = ('image', 'npy', 'imagenet32', 'imagenet64', 'mnist',
                 'cifar10', 'fashionmnist', 'lsun', 'lsun_class', 'celeba')
    if name not in supported:
        # Fail fast, before any transform or dataset object is created.
        # ``NotImplemented`` is a comparison sentinel, not an exception, so
        # raising it would surface as an unrelated TypeError.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    # Default three-channel pipeline for the folder-based and LSUN kinds.
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        # Add torch.rand noise scaled by 1/128 to every normalized tensor.
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])

    # Single-channel pipeline shared by MNIST and FashionMNIST.
    mnist_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now: no transform is passed here.
        # A tuple of extensions behaves the same whether the library
        # iterates over it or hands it straight to ``str.endswith``.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ('npy',))
        nlabels = len(dataset.classes)
    elif name in ('imagenet32', 'imagenet64'):
        # Both names are built identically; the resolution comes from
        # ``size``, not from the name.
        dataset = ImageNetDataset(dir=data_dir, img_size=size)
        nlabels = 1000
    elif name == 'mnist':
        dataset = datasets.MNIST(root=data_dir, train=train, download=True,
                                 transform=mnist_transform)
        nlabels = 10
    elif name == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        if train:
            # Training split: padded random crop plus random flip.
            steps = [
                transforms.Resize(size),
                transforms.RandomCrop(size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        else:
            # Test split: resize and normalize only.
            steps = [
                transforms.Resize(size),
                transforms.ToTensor(),
                normalize,
            ]
        dataset = datasets.CIFAR10(root=data_dir, train=train, download=True,
                                   transform=transforms.Compose(steps))
        nlabels = 10
    elif name == 'fashionmnist':
        dataset = datasets.FashionMNIST(root=data_dir, train=train,
                                        download=True,
                                        transform=mnist_transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        # Every sample gets the constant target 0, hence a single label.
        dataset = datasets.LSUNClass(data_dir, transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    elif name == 'celeba':
        # Loaded as a plain image folder; reported as a single label.
        dataset = datasets.ImageFolder(root=data_dir, transform=transform)
        nlabels = 1
    else:
        # Defensive: only reachable if ``supported`` and the branches above
        # drift apart.
        raise NotImplementedError(
            'Unsupported dataset name: {!r}'.format(name))

    return dataset, nlabels