def get_dataset(args):
    """Build train and test DataLoaders for the dataset named by ``args.data``.

    Images are left as PIL (only resized/augmented); ``fast_collate`` packs them
    into a single ``uint8`` tensor, avoiding a per-image ``ToTensor`` pass.

    Args:
        args: namespace with at least ``data``, ``datadir``, ``imagesize``,
            ``batch_size``, ``test_batch_size``, ``nworkers``, ``distributed``.

    Returns:
        (train_loader, test_loader, data_shape) where
        data_shape == (channels, im_size, im_size).

    Raises:
        ValueError: if ``args.data`` is not a recognized dataset name.
    """
    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size)])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root=args.datadir, train=True, transform=trans(im_size), download=True)
        test_set = dset.MNIST(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == "svhn":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.SVHN(root=args.datadir, split="train", transform=trans(im_size), download=True)
        test_set = dset.SVHN(root=args.datadir, split="test", transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root=args.datadir, train=True,
            transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
            ]),
            download=True,
        )
        # FIX: the test set previously had transform=None, so images were never
        # resized and fast_collate's (im_dim, im_size, im_size) buffer would not
        # match whenever args.imagesize != 32. Resize test images like train.
        test_set = dset.CIFAR10(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        train_set = CelebAHQ(train=True, root=args.datadir)
        test_set = CelebAHQ(train=False, root=args.datadir)
    elif args.data == 'imagenet64':
        im_dim = 3
        # Imagenet64 is fixed-resolution; force args.imagesize to match.
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        train_set = Imagenet64(train=True, root=args.datadir)
        test_set = Imagenet64(train=False, root=args.datadir)
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        # FIX: root was hard-coded to 'data'; use args.datadir like every other branch.
        train_set = dset.LSUN(
            args.datadir, ['church_outdoor_train'],
            transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ]),
        )
        test_set = dset.LSUN(
            args.datadir, ['church_outdoor_val'],
            transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ]),
        )
    else:
        # FIX: previously an unknown name fell through and raised UnboundLocalError.
        raise ValueError("Unknown dataset: {}".format(args.data))

    data_shape = (im_dim, im_size, im_size)

    def fast_collate(batch):
        """Collate (PIL image, label) pairs into a uint8 NCHW tensor + label tensor."""
        imgs = [sample[0] for sample in batch]
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        tensor = torch.zeros((len(imgs), im_dim, im_size, im_size), dtype=torch.uint8)
        for i, img in enumerate(imgs):
            arr = np.asarray(img, dtype=np.uint8)
            if arr.ndim < 3:
                # Grayscale: add a trailing channel axis so rollaxis works uniformly.
                arr = np.expand_dims(arr, axis=-1)
            arr = np.rollaxis(arr, 2)  # HWC -> CHW
            tensor[i] = torch.from_numpy(arr)
        return tensor, targets

    train_sampler = (DistributedSampler(train_set, num_replicas=env_world_size(), rank=env_rank())
                     if args.distributed else None)
    if not args.distributed:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=args.batch_size, shuffle=True,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate,
        )
    else:
        # DistributedSampler handles shuffling/sharding; shuffle must stay unset.
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=args.batch_size, sampler=train_sampler,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate,
        )

    test_sampler = (DistributedSampler(test_set, num_replicas=env_world_size(), rank=env_rank())
                    if args.distributed else None)
    if not args.distributed:
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=args.test_batch_size, shuffle=False,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate,
        )
    else:
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=args.test_batch_size, sampler=test_sampler,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate,
        )

    return train_loader, test_loader, data_shape
# NOTE(review): this duplicate definition shadows the earlier get_dataset in
# this module. Unlike that one, it returns the raw train *set* (not a loader)
# plus a test DataLoader — keep or remove one of the two deliberately.
def get_dataset(args):
    """Build the train dataset and a test DataLoader for ``args.data``.

    Transforms here include ``ToTensor`` (float tensors), unlike the
    fast-collate variant above.

    Args:
        args: namespace with at least ``data``, ``datadir``, ``imagesize``,
            ``test_batch_size``.

    Returns:
        (train_set, test_loader, data_shape) where
        data_shape == (channels, im_size, im_size).

    Raises:
        ValueError: if ``args.data`` is not a recognized dataset name.
    """
    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor()])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        # FIX: roots were hard-coded ("./data", "/mnt/data/scratch/data/");
        # use args.datadir consistently, as the sibling definition does.
        train_set = dset.MNIST(root=args.datadir, train=True, transform=trans(im_size), download=True)
        test_set = dset.MNIST(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root=args.datadir, train=True,
            transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
            ]),
            download=True,
        )
        test_set = dset.CIFAR10(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == 'imagenet64':
        im_dim = 3
        # Imagenet64 is fixed-resolution; force args.imagesize to match.
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        train_set = Imagenet64(train=True, root=args.datadir, transform=tforms.ToTensor())
        test_set = Imagenet64(train=False, root=args.datadir, transform=tforms.ToTensor())
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        train_set = CelebAHQ(
            train=True, root=args.datadir,
            transform=tforms.Compose([
                tforms.ToPILImage(),
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
            ]),
        )
        test_set = CelebAHQ(
            train=False, root=args.datadir,
            transform=tforms.Compose([
                tforms.ToPILImage(),
                tforms.Resize(im_size),
                tforms.ToTensor(),
            ]),
        )
    else:
        # FIX: previously an unknown name fell through and raised UnboundLocalError.
        raise ValueError("Unknown dataset: {}".format(args.data))

    data_shape = (im_dim, im_size, im_size)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.test_batch_size,
        shuffle=False, drop_last=True,
    )
    return train_set, test_loader, data_shape