def get_dataset(args):
    """Build train/test DataLoaders for the dataset named by ``args.data``.

    Supported datasets: mnist, svhn, cifar10, celebahq, imagenet64,
    lsun_church.  Returns ``(train_loader, test_loader, data_shape)`` where
    ``data_shape`` is ``(channels, height, width)``.

    Raises:
        ValueError: if ``args.data`` is not one of the supported names.
    """
    # Resize-only transform: images stay PIL; fast_collate below converts
    # them straight to uint8 tensors, skipping a per-image ToTensor() pass.
    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size)])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root=args.datadir, train=True, transform=trans(im_size), download=True)
        test_set = dset.MNIST(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == "svhn":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.SVHN(root=args.datadir, split="train", transform=trans(im_size), download=True)
        test_set = dset.SVHN(root=args.datadir, split="test", transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root=args.datadir, train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
            ]), download=True
        )
        # Resize the test split too: with transform=None, fast_collate would
        # fail whenever im_size != 32 (raw CIFAR images are 32x32).
        test_set = dset.CIFAR10(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        # No transform: the dataset is presumably stored at the desired
        # resolution already -- TODO confirm against CelebAHQ's on-disk format.
        train_set = CelebAHQ(train=True, root=args.datadir)
        test_set = CelebAHQ(train=False, root=args.datadir)
    elif args.data == 'imagenet64':
        im_dim = 3
        # Imagenet64 is fixed-resolution; force 64 so downstream shapes agree.
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        train_set = Imagenet64(train=True, root=args.datadir)
        test_set = Imagenet64(train=False, root=args.datadir)
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        # Use args.datadir (was hard-coded 'data') for consistency with the
        # other dataset branches.
        train_set = dset.LSUN(
            args.datadir, ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
        test_set = dset.LSUN(
            args.datadir, ['church_outdoor_val'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
    else:
        # Previously an unknown name fell through to a confusing NameError on
        # data_shape; fail fast with a clear message instead.
        raise ValueError("Unknown dataset: {}".format(args.data))
    data_shape = (im_dim, im_size, im_size)

    def fast_collate(batch):
        """Collate PIL images into an (N, C, H, W) uint8 tensor + int64 targets."""
        imgs = [sample[0] for sample in batch]
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        tensor = torch.zeros((len(imgs), im_dim, im_size, im_size), dtype=torch.uint8)
        for i, img in enumerate(imgs):
            arr = np.asarray(img, dtype=np.uint8)
            if arr.ndim < 3:
                # Grayscale images arrive as (H, W); add a channel axis.
                arr = np.expand_dims(arr, axis=-1)
            arr = np.rollaxis(arr, 2)  # HWC -> CHW
            tensor[i] = torch.from_numpy(arr)
        return tensor, targets

    # In distributed mode each rank reads a disjoint shard via the sampler
    # (which also shuffles), so DataLoader-level shuffle must stay off; with
    # sampler=None this reduces to shuffle=True for training.
    train_sampler = (DistributedSampler(train_set,
                                        num_replicas=env_world_size(), rank=env_rank()) if args.distributed
                     else None)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=args.batch_size,
        shuffle=(train_sampler is None), sampler=train_sampler,
        num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
    )

    test_sampler = (DistributedSampler(test_set,
                                       num_replicas=env_world_size(), rank=env_rank()) if args.distributed
                    else None)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.test_batch_size,
        shuffle=False, sampler=test_sampler,
        num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
    )
    return train_loader, test_loader, data_shape
# --- Example 2: alternative get_dataset using ToTensor transforms ---
def get_dataset(args):
    """Build the training dataset and a test DataLoader for ``args.data``.

    Supported datasets: mnist, cifar10, imagenet64, celebahq.  Returns
    ``(train_set, test_loader, data_shape)`` -- note the first element is a
    Dataset, not a DataLoader; the caller is expected to wrap it itself.

    Raises:
        ValueError: if ``args.data`` is not one of the supported names.
    """
    # Shared transform: resize then convert PIL image -> float tensor.
    trans = lambda im_size: tforms.Compose(
        [tforms.Resize(im_size), tforms.ToTensor()])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root="./data",
                               train=True,
                               transform=trans(im_size),
                               download=True)
        test_set = dset.MNIST(root="./data",
                              train=False,
                              transform=trans(im_size),
                              download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        # Horizontal-flip augmentation on the training split only.
        train_set = dset.CIFAR10(root="./data",
                                 train=True,
                                 transform=tforms.Compose([
                                     tforms.Resize(im_size),
                                     tforms.RandomHorizontalFlip(),
                                     tforms.ToTensor(),
                                 ]),
                                 download=True)
        test_set = dset.CIFAR10(root="./data",
                                train=False,
                                transform=trans(im_size),
                                download=True)
    elif args.data == 'imagenet64':
        im_dim = 3
        # Imagenet64 is fixed-resolution; force 64 so downstream shapes agree.
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        # NOTE(review): hard-coded machine-specific root; consider reading it
        # from args like example 1 does.
        train_set = Imagenet64(train=True,
                               root='/mnt/data/scratch/data/',
                               transform=tforms.ToTensor())
        test_set = Imagenet64(train=False,
                              root='/mnt/data/scratch/data/',
                              transform=tforms.ToTensor())
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        # NOTE(review): hard-coded machine-specific root; consider reading it
        # from args like example 1 does.
        train_set = CelebAHQ(train=True,
                             root='/mnt/data/scratch/data/',
                             transform=tforms.Compose([
                                 tforms.ToPILImage(),
                                 tforms.Resize(im_size),
                                 tforms.RandomHorizontalFlip(),
                                 tforms.ToTensor()
                             ]))
        test_set = CelebAHQ(train=False,
                            root='/mnt/data/scratch/data/',
                            transform=tforms.Compose([
                                tforms.ToPILImage(),
                                tforms.Resize(im_size),
                                tforms.ToTensor()
                            ]))
    else:
        # Previously an unknown name fell through to a confusing NameError on
        # data_shape; fail fast with a clear message instead.
        raise ValueError("Unknown dataset: {}".format(args.data))
    data_shape = (im_dim, im_size, im_size)

    # drop_last keeps every evaluation batch the same size (the tail batch is
    # discarded), so per-batch metrics average uniformly.
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              drop_last=True)
    return train_set, test_loader, data_shape