def main(args):
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    # trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
    # trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    # testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
    # testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    kwargs = {'num_workers': 8, 'pin_memory': False}

    # trainloader = torch.utils.data.DataLoader(
    #     datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()])),
    #     batch_size=args.batch_size, shuffle=True, **kwargs)
    # testloader = torch.utils.data.DataLoader(
    #     datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor()])),
    #     batch_size=args.batch_size, shuffle=True, **kwargs)

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # dset = CustomImageFolder('data/CelebA', transform)
    # trainloader = torch.utils.data.DataLoader(dset, batch_size=args.batch_size, shuffle=True,
    #                                           num_workers=8, pin_memory=True, drop_last=True)
    # testloader = torch.utils.data.DataLoader(dset, batch_size=args.batch_size, shuffle=True,
    #                                          num_workers=8, pin_memory=True, drop_last=True)

    trainloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='train', download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    testloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='test', transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
def get_celeba_loaders(batch_train, batch_test):
    test_num = 128
    images = glob.glob(os.path.join(".", "data", "celeba", "images", "*.jpg"))

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ])
    datasets_list = {
        "train": datasets.CelebA('./data', split="train", target_type='attr',
                                 download=False, transform=transform),
        "test": datasets.CelebA('./data', split="test", target_type='attr',
                                download=False, transform=transform)
    }
    dataloaders_list = {
        "train": DataLoader(datasets_list["train"], batch_size=batch_train, shuffle=True),
        "test": DataLoader(datasets_list["test"], batch_size=batch_test, shuffle=False)
    }
    return dataloaders_list
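A minimal usage sketch for the loaders returned above (batch sizes are illustrative; with target_type='attr', each batch pairs images with 40-dimensional binary attribute vectors):

# Sketch only: batch sizes are arbitrary.
loaders = get_celeba_loaders(batch_train=64, batch_test=128)
images, attrs = next(iter(loaders["train"]))  # images: (64, 3, 64, 64) in [0, 1]; attrs: (64, 40)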
def create_datasets(self):
    # Taken from pytorch MNIST demo.
    kwargs = {'num_workers': 1, 'pin_memory': True} if self.use_cuda else {}
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.CelebA('./celeba_data', split='train', download=True, transform=transform),
        batch_size=self.args.batch, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CelebA('./celeba_data', split='test', download=True, transform=transform),
        batch_size=self.args.batch, shuffle=True, **kwargs)
    return train_loader, test_loader
def __init__(self, root="", transform="default", download=False): self.root = root if isinstance(transform, str) and transform == "default": self.trans = transforms.ToTensor() elif isinstance(transform, str) and transform == "vae": """ transforms from https://github.com/AntixK/PyTorch-VAE/blob/8700d245a9735640dda458db4cf40708caf2e77f/experiment.py#L14 """ SetRange = transforms.Lambda(lambda X: 2 * X - 1.) self.trans = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.CenterCrop(148), transforms.Resize(64), transforms.ToTensor(), SetRange ]) else: self.trans = transform if download: self.download() self.__train_set = datasets.CelebA(root=str(root), split="train", transform=self.trans, download=False) self.__test_set = datasets.CelebA(root=str(root), split="test", transform=self.trans, download=False) self.__valid_set = datasets.CelebA(root=str(root), split="valid", transform=self.trans, download=False)
def setup_data_loaders(self):
    if self.dataset == 'celeba':
        transform_list = [
            transforms.CenterCrop(140),
            transforms.Resize((64, 64), PIL.Image.ANTIALIAS),
            transforms.ToTensor()
        ]
        # if self.input_normalize_sym:
        #     D = 64*64*3
        #     transform_list.append(transforms.LinearTransformation(2*torch.eye(D), -.5*torch.ones(D)))
        transform = transforms.Compose(transform_list)
        train_dataset = datasets.CelebA(self.datadir, split='train', target_type='attr',
                                        download=True, transform=transform)
        test_dataset = datasets.CelebA(self.datadir, split='test', target_type='attr',
                                       download=True, transform=transform)
        self.nlabels = 0
    elif self.dataset == 'mnist':
        train_dataset = datasets.MNIST(self.datadir, train=True, target_transform=None,
                                       download=True, transform=transforms.ToTensor())
        test_dataset = datasets.MNIST(self.datadir, train=False, target_transform=None,
                                      download=True, transform=transforms.ToTensor())
        self.nlabels = 10
    elif self.dataset == 'fashionmnist':
        train_dataset = datasets.FashionMNIST(self.datadir, train=True, target_transform=None,
                                              download=True, transform=transforms.ToTensor())
        test_dataset = datasets.FashionMNIST(self.datadir, train=False, target_transform=None,
                                             download=True, transform=transforms.ToTensor())
        self.nlabels = 10
    elif self.dataset == 'kmnist':
        train_dataset = datasets.KMNIST(self.datadir, train=True, target_transform=None,
                                        download=True, transform=transforms.ToTensor())
        test_dataset = datasets.KMNIST(self.datadir, train=False, target_transform=None,
                                       download=True, transform=transforms.ToTensor())
        self.nlabels = 10
    else:
        # Was `dataset`, an undefined name that would raise NameError.
        raise Exception("Dataset not found: " + self.dataset)

    if self.limit_train_size is not None:
        train_dataset = torch.utils.data.random_split(
            train_dataset, [self.limit_train_size, len(train_dataset) - self.limit_train_size])[0]

    self.train_loader = torch.utils.data.DataLoader(
        DatasetWithIndices(train_dataset, self.input_normalize_sym),
        batch_size=self.batch_size, shuffle=True, **self.dataloader_kwargs)
    self.test_loader = torch.utils.data.DataLoader(
        DatasetWithIndices(test_dataset, self.input_normalize_sym),
        batch_size=self.batch_size, shuffle=False, **self.dataloader_kwargs)
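The commented-out LinearTransformation above targets symmetric inputs; a sketch of the more common route to [-1, 1] inputs, assuming 3-channel images already mapped to [0, 1] by ToTensor():

# Sketch: symmetric normalization to [-1, 1], an alternative to the
# commented-out LinearTransformation (assumes 3-channel inputs in [0, 1]).
symmetric_transform = transforms.Compose([
    transforms.CenterCrop(140),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),                                   # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 -> [-1, 1]
])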
def get_celeba_loaders(self, data_augment):
    train_dataset = datasets.CelebA(self.data_root, split='train', download=True,
                                    transform=SimCLRDataTransform(data_augment))
    valid_dataset = datasets.CelebA(self.data_root, split='valid', download=True,
                                    transform=SimCLRDataTransform(data_augment))
    train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                              num_workers=self.num_workers, drop_last=True, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size,
                              num_workers=self.num_workers, drop_last=True)
    return train_loader, valid_loader
def encode(save_root, model_file, data_folder, model_name='ca', dataset_name='celeba',
           batch_size=64, device='cuda:0', out_dim=256):
    os.makedirs(save_root, exist_ok=True)
    os.makedirs(data_folder, exist_ok=True)

    if dataset_name == 'celeba':
        train_loader = DataLoader(
            datasets.CelebA(data_folder, split='train', download=True, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=False)
        valid_loader = DataLoader(
            datasets.CelebA(data_folder, split='valid', download=True, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=False)
    elif dataset_name == 'stanfordCars':
        t = transforms.Compose([
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)
        ])
        train_data_dir = os.path.join(data_folder, 'cars_train/')
        train_annos = os.path.join(data_folder, 'devkit/cars_train_annos.mat')
        train_loader = DataLoader(CarsDataset(train_annos, train_data_dir, t),
                                  batch_size=batch_size, shuffle=False)
        valid_data_dir = os.path.join(data_folder, 'cars_test/')
        valid_annos = os.path.join(data_folder, 'devkit/cars_test_annos_withlabels.mat')
        valid_loader = DataLoader(CarsDataset(valid_annos, valid_data_dir, t),
                                  batch_size=batch_size, shuffle=False)
    elif dataset_name == 'compCars':
        t = transforms.Compose([
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor()
        ])
        train_loader = DataLoader(CompCars(data_folder, True, t), batch_size=batch_size, shuffle=False)
        valid_loader = DataLoader(CompCars(data_folder, False, t), batch_size=batch_size, shuffle=False)

    model = ResNetSimCLR('resnet50', out_dim)
    model.load_state_dict(torch.load(model_file, map_location=device))
    model = model.to(device)
    model.eval()

    print('Starting on training data')
    train_encodings = []
    for x, _ in train_loader:
        x = x.to(device)
        h, _ = model(x)
        train_encodings.append(h.cpu().detach())
    torch.save(torch.cat(train_encodings, dim=0),
               os.path.join(save_root, f'{dataset_name}-{model_name}model-train_encodings.pt'))

    print('Starting on validation data')
    valid_encodings = []
    for x, _ in valid_loader:
        x = x.to(device)
        h, _ = model(x)
        if len(h.shape) == 1:
            h = h.unsqueeze(0)
        valid_encodings.append(h.cpu().detach())
    torch.save(torch.cat(valid_encodings, dim=0),
               os.path.join(save_root, f'{dataset_name}-{model_name}model-valid_encodings.pt'))
def get_data_CelebA(transform, batch_size, download=True, root="/data"):
    print("Loading trainset...")
    trainset = Datasets.CelebA(root=root, split="train", transform=transform, download=download)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

    print("Loading testset...")
    testset = Datasets.CelebA(root=root, split='test', download=download, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8)

    print("Done!")
    return trainloader, testloader
def get_loader(self, sz, bs, dt=None, num_workers=10):
    if dt is None:
        dt = get_default_dt(self.basic_types, sz)

    if self.basic_types is None:
        train_dataset = datasets.ImageFolder(self.train_dir, dt)
    elif self.basic_types == 'MNIST':
        train_dataset = datasets.MNIST(self.train_dir, train=True, download=True, transform=dt)
    elif self.basic_types == 'CIFAR10':
        train_dataset = datasets.CIFAR10(self.train_dir, train=True, download=True, transform=dt)
    elif self.basic_types == 'CelebA':
        train_dataset = datasets.CelebA(self.train_dir, download=True, transform=dt)

    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=self.shuffle, num_workers=num_workers)
    return train_loader
def make_dataset(op):
    print('Downloading data...')
    dataset_name = op.data.name
    if dataset_name == 'celeba':
        return dset.CelebA(root='celeba', download=True, transform=make_transforms(op))
    elif dataset_name == 'fashion_mnist':
        return dset.FashionMNIST(root='fashion_mnist', download=True, transform=make_transforms(op))
    elif dataset_name == 'pokemon':
        return PokeSprites(op)
    elif dataset_name in Pix2Pix_Datasets.keys():
        return Pix2Pix(root=dataset_name, dataset_name=dataset_name, download=True,
                       transform=make_transforms(op))
    else:
        raise ValueError(f'{dataset_name} not supported!')
def get_celeba_loaders(path, batch_size, image_size):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    train_set = datasets.CelebA(path, split='train', download=True, transform=transform)
    loader = DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=4)
    loader = iter(loader)
    while True:
        try:
            yield next(loader)
        except StopIteration:
            # Epoch exhausted: rebuild the loader and keep yielding batches.
            loader = DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=4)
            loader = iter(loader)
            yield next(loader)
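Because the generator above restarts its DataLoader on StopIteration, next() can be called indefinitely with no epoch bookkeeping; a usage sketch (path and sizes are placeholders):

celeba_iter = get_celeba_loaders('./data', batch_size=16, image_size=64)
for step in range(1000):  # iteration counted in steps, not epochs
    images, attrs = next(celeba_iter)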
def __init__(self, root, train=True, transform=None, size=[32, 32], num_points=200,
             eval_mode='none', num_points_eval=200, download=False):
    self.name = 'CELEBA'
    split = 'train' if train else 'test'
    self.size = size  # original CelebA size: (218, 178)
    self.num_points = num_points
    self.eval_mode = eval_mode
    self.num_points_eval = num_points_eval
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    self.dataset = datasets.CelebA(root=root, split=split, download=download, transform=transform)
    # Normalized (row, col) coordinates for every pixel of a size[0] x size[1] grid,
    # enumerated in row-major order.
    self.coordinates = torch.from_numpy(
        np.array([[int(i / size[1]) / size[0], (i % size[1]) / size[1]]
                  for i in range(size[0] * size[1])])).float()
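The list comprehension above enumerates pixels row-major and normalizes (row, col) by the grid height and width; an equivalent vectorized construction, shown only for comparison (not part of the original class):

H, W = 32, 32
rows = torch.arange(H).repeat_interleave(W).float() / H  # 0/H, ..., (H-1)/H, each repeated W times
cols = torch.arange(W).repeat(H).float() / W             # 0/W, ..., (W-1)/W, tiled H times
coordinates = torch.stack([rows, cols], dim=1)           # shape (H * W, 2), matches the loop above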
def __init__(self, root, img_size=64, data='celeb-a'):
    super().__init__()
    self.root = root
    self.img_size = img_size
    ## Train Set
    self.tr_data = data
    ## Transform functions. Note: Resize takes a single size argument (int or
    ## (H, W) tuple), not two ints, and Normalize operates on tensors, so
    ## ToTensor() must come before it (the original code had both wrong).
    self.transform = Compose([
        Resize((self.img_size, self.img_size)),
        ToTensor(),
        Normalize(mean=[0.5], std=[0.5])
    ])
    if self.tr_data == 'celeb-a':
        self.data = datasets.CelebA(self.root, split='train', download=True)
    elif self.tr_data == 'mnist':
        self.data = datasets.MNIST(self.root, train=True, download=True)
    else:
        self.data_list = []
        for image in os.listdir(self.root):
            self.data_list.append(self.root + '/' + image)
def get_dataloader(dataset, args):
    if dataset == 'celebA':
        m = (0.5, 0.5, 0.5)
        s = (0.5, 0.5, 0.5)
    else:
        m = (0.5,)
        s = (0.5,)

    transform_img = transforms.Compose([
        transforms.Resize(size=args.img_size, interpolation=0),  # 0 = nearest-neighbor
        transforms.ToTensor(),
        transforms.Normalize(mean=m, std=s),
    ])

    if dataset == 'mnist':
        dataset = datasets.MNIST(root='./DATA', train=True, transform=transform_img, download=True)
    elif dataset == 'fashionmnist':
        dataset = datasets.FashionMNIST(root='./DATA', train=True, download=True, transform=transform_img)
    elif dataset == 'celebA':
        dataset = datasets.CelebA(root='./DATA', split='train', download=True, transform=transform_img)

    loader = DataLoader(dataset=dataset, batch_size=args.batch, shuffle=True, num_workers=4)
    return loader
def __init__(self, root, split, normal_class, normal=True, transform=None,
             abnormal_class=None, extended_attribute_list=False):
    assert normal_class == 0
    self.data = datasets.CelebA(root, split)
    self.transform = transform
    if extended_attribute_list:
        self.attributes = ["Bags_Under_Eyes", "Bald", "Bangs", "Eyeglasses", "Goatee",
                           "Heavy_Makeup", "Mustache", "Sideburns", "Wearing_Hat"]
    else:
        self.attributes = ["Bald", "Mustache", "Bangs", "Eyeglasses", "Wearing_Hat"]

    if normal:
        # Normal samples: images where none of the listed attributes is present.
        byte_index = torch.ones(len(self.data), dtype=torch.bool)
        for attr_name in self.attributes:
            byte_index = byte_index.logical_and(
                self.data.attr[:, self.data.attr_names.index(attr_name)] == 0)
        self.active_indexes = torch.nonzero(byte_index, as_tuple=False).numpy().flatten()
    else:
        assert abnormal_class in self.attributes
        # Filter images where this attribute is present
        byte_index = self.data.attr[:, self.data.attr_names.index(abnormal_class)] == 1
        # and all other attributes are absent.
        for attr_name in self.attributes:
            if attr_name != abnormal_class:
                byte_index = byte_index.logical_and(
                    self.data.attr[:, self.data.attr_names.index(attr_name)] == 0)
        self.active_indexes = torch.nonzero(byte_index, as_tuple=False).numpy().flatten()

    if split == 'train':
        self._min_dataset_size = 10000  # as required in _NormalAnomalyBase
    else:
        self._min_dataset_size = 0
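An illustrative construction of this anomaly dataset (the class name CelebAAnomaly is hypothetical; the arguments follow the __init__ signature above):

# Normal: faces exhibiting none of the listed attributes.
normal_train = CelebAAnomaly('./data', split='train', normal_class=0, normal=True)
# Abnormal: faces with exactly one listed attribute set, here "Eyeglasses".
abnormal_test = CelebAAnomaly('./data', split='test', normal_class=0,
                              normal=False, abnormal_class='Eyeglasses')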
def load_dataset(dataset_name, root='data'):
    img_preprocess = preprocess(dataset_name)
    if dataset_name == "mnist":
        train_dataset = datasets.MNIST(root, train=True, download=True, transform=img_preprocess)
        val_dataset = datasets.MNIST(root, train=False, download=True, transform=img_preprocess)
        label_names = list(map(str, range(10)))
    elif dataset_name == "celeba":
        # Note: download=True fails when the daily quota on this dataset has been reached
        train_dataset = datasets.CelebA(root, split='train', target_type='attr',
                                        transform=img_preprocess, download=True)
        val_dataset = datasets.CelebA(root, split='valid', target_type='attr',
                                      transform=img_preprocess, download=True)
        label_names = []
    elif dataset_name == "cifar10":
        train_dataset = datasets.CIFAR10(root, train=True, download=True, transform=img_preprocess)
        val_dataset = datasets.CIFAR10(root, train=False, download=True, transform=img_preprocess)
        label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
    else:
        print("Invalid dataset name, exiting...")
        return
    return train_dataset, val_dataset, label_names
def prepare_data(self):
    path = os.getcwd()
    dataset = dset.CelebA(path, split='train',
                          transform=transforms.Compose([
                              transforms.Resize(128),
                              transforms.CenterCrop(128),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
                          target_transform=None, target_type='attr', download=False)
    self.dataset = dataset
def prepare_data(self, *args, **kwargs):
    self.train_data = datasets.CelebA(
        self.data_root, target_type='attr', split='train', download=self.download,
        target_transform=self.__target_transform, transform=self.__transform)
    self.val_data = datasets.CelebA(
        self.data_root, target_type='attr', split='valid', download=self.download,
        target_transform=self.__target_transform, transform=self.__transform)
    self.__set_group_counts()
def load_celeba(batch_size=16):
    transform = transforms.Compose([
        transforms.CenterCrop(89),
        transforms.Resize(32),
        transforms.ToTensor()
    ])
    dataset = ds.CelebA('./data', transform=transform, split='train', download=True)
    data_loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_loader
def load_data(self, dataset, batch_size) -> None:
    download = not os.path.exists('../data/' + dataset)
    if dataset == 'MNIST':
        self.train_dataset = datasets.MNIST(root='../data', train=True,
                                            transform=ToTensor(), download=download)
    elif dataset == 'CelebA':
        self.train_dataset = datasets.CelebA(root='../data', split='train',
                                             transform=ToTensor(), download=download)
    self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
def raw_dataset(self, data_dir: str, download: bool, split: str, transform):
    assert split in ['train', 'val', 'test']
    if split == 'val':
        split = 'valid'  # torchvision's CelebA names the validation split 'valid'
    return datasets.CelebA(data_dir, download=download, split=split,
                           transform=transform, target_type=self.target_type)
def __init__(self, data_root, debug, cuda_enabled, quiet, checkpoint):
    # Hyperparams
    self.batch_size = 64
    self.epochs = 100
    self.z_dim = 100
    self.lr = 0.0002
    self.image_size = 64
    self.cuda_enabled = cuda_enabled
    self.quiet = quiet
    self.checkpoint_dir = 'checkpoints'
    if not os.path.isdir(self.checkpoint_dir):
        os.mkdir(self.checkpoint_dir)

    # Data augmentations
    transform = transforms.Compose([
        transforms.Resize((self.image_size, self.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # Instantiate data loaders
    self.train_dataset = datasets.CelebA(root=data_root, split='all', download=True, transform=transform)
    self.train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                    batch_size=self.batch_size, shuffle=True)
    self.num_channels = 3

    # Setup discriminator
    self.d = D(self.num_channels)
    if self.cuda_enabled:
        self.d.cuda()

    # Setup generator
    self.g = G(self.z_dim, self.num_channels)
    if self.cuda_enabled:
        self.g.cuda()

    # Load checkpoint
    self.cur_epoch = 0
    if checkpoint is not None:
        state_dict = torch.load(checkpoint)
        self.g.load_state_dict(state_dict['g'])
        self.d.load_state_dict(state_dict['d'])
        self.cur_epoch = int(os.path.basename(checkpoint).split('_')[0])

    # Setup loss and optimizers
    self.loss = nn.BCELoss()
    self.g_opt = optim.Adam(self.g.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.d_opt = optim.Adam(self.d.parameters(), lr=self.lr, betas=(0.5, 0.999))

    # Setup options
    if debug:
        self.epochs = 1
def __init__(self):
    super(CelebA, self).__init__()
    self.binarized = config.get("data", "binarized", default=False)
    self.gauss_noise = config.get("data", "gauss_noise", default=False)
    self.noise_std = config.get("data", "noise_std", default=0.01)
    self.flattened = config.get("data", "flattened", default=False)
    if self.binarized:
        assert not self.gauss_noise
    self.data_path = os.path.join("workspace", "datasets", "celeba")

    # Get datasets
    im_transformer = transforms.Compose([
        transforms.CenterCrop(140),
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor()
    ])
    self.train = datasets.CelebA(self.data_path, split="train", target_type=[],
                                 transform=im_transformer, download=True)
    self.val = datasets.CelebA(self.data_path, split="valid", target_type=[],
                               transform=im_transformer, download=True)
    self.test = datasets.CelebA(self.data_path, split="test", target_type=[],
                                transform=im_transformer, download=True)
    self.train = UnlabeledDataset(self.train)
    self.test = UnlabeledDataset(self.test)
    self.val = UnlabeledDataset(self.val)
def get_data_loader():
    dataset = dset.CelebA(root=args.celeba_loc, split="all", download=False,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.CenterCrop(image_size),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ]))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=workers)
def main():
    tfms = transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor()
    ])
    train_dataset = datasets.CelebA(root='data/', split='train', target_type='attr',
                                    transform=tfms, download=True)
    valid_dataset = datasets.CelebA(root='data/', split='valid', target_type='attr',
                                    transform=tfms, download=True)
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=4, shuffle=False)

    # Visualize a batch
    if DEBUG:
        for batch_idx, (images, labels) in enumerate(train_dataloader):
            if batch_idx > 0:
                break
            img = make_grid(images)
            img = img.numpy().transpose(1, 2, 0)
            plt.imshow(img)
            print(labels)
            plt.show()

    # Training loop
    model = Resnet18(num_classes=2)
    optimizer = torch.optim.Adam(model.parameters())
    import ipdb; ipdb.set_trace()  # noqa # yapf: disable
def setup_datasets(self):
    transform = transforms.Compose([
        transforms.Resize(self.shape),
        transforms.CenterCrop(self.shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dsb_train = datasets.CelebA(self.dspath, download=False, split='train', transform=transform)
    # The original call omitted `split`, which silently defaults to 'train';
    # a validation dataset presumably wants the 'valid' split.
    dsb_val = datasets.CelebA(self.dspath, download=False, split='valid', transform=transform)
    return dsb_train, dsb_val
def celebA_feature_set(split='train', shuffle=None, batch_size=128, attributes=None,
                       augm_type='default', out_size=224, config_dict=None):
    if split == 'test' and not augm_type == 'none':
        print('WARNING: Test set in use with data augmentation')

    if shuffle is None:
        shuffle = split == 'train'

    if attributes is None:
        attributes = celebA_attributes
        target_transform = None
    else:
        target_transform = get_celebA_target_transform(attributes)

    augm_config = {}
    augm = get_celebA_augmentation(augm_type, out_size=out_size, config_dict=augm_config)
    path = get_celebA_path()
    dataset = datasets.CelebA(path, split=split, target_type='attr', transform=augm,
                              target_transform=target_transform, download=False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=8)

    if config_dict is not None:
        config_dict['Dataset'] = 'CelebA'
        config_dict['Batch size'] = batch_size
        config_dict['Augmentation'] = augm_config
        config_dict['Attributes'] = attributes

    return loader
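A sketch of how the config_dict side channel above can be used for experiment logging (celebA_attributes and the helper functions come from the surrounding project):

config = {}
loader = celebA_feature_set(split='train', batch_size=128, augm_type='none', config_dict=config)
# config now holds {'Dataset': 'CelebA', 'Batch size': 128, 'Augmentation': {...}, 'Attributes': [...]}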
def __init__(self, split, dir_name='data/', subset_percentage=1,
             protected_percentage=1, balance_protected=True):
    self.transform_image = transform_image
    self.gender_index = 20  # index of the 'Male' attribute in CelebA's attr vector
    self.dataset = datasets.CelebA(dir_name, split=split, transform=transform_image,
                                   target_transform=None, download=True)
    if subset_percentage < 1:
        self.dataset = Subset(self.dataset,
                              range(ceil(subset_percentage * len(self.dataset))))

    # Handle protected split (only relevant for train).
    self.protected = np.zeros(len(self.dataset))
    if split == 'train':
        if balance_protected:
            # Get gender information.
            genders = np.array([self.dataset[i][1][self.gender_index]
                                for i in range(len(self.dataset))])
            male_idxs = np.where(genders == 1)[0].tolist()
            female_idxs = np.where(genders == 0)[0].tolist()

            # Set number of protected data points.
            max_percentage = min(len(male_idxs), len(female_idxs)) * 2.0 / len(self.dataset)
            if protected_percentage > max_percentage:
                protected_percentage = max_percentage
            num_protected = ceil(protected_percentage * len(self.dataset))
            if num_protected % 2 == 1:
                num_protected -= 1

            # Create protected split.
            self.protected_split = random.sample(male_idxs, int(num_protected / 2))
            self.protected_split.extend(random.sample(female_idxs, int(num_protected / 2)))
            self.protected[self.protected_split] = 1
        else:
            num_protected = ceil(protected_percentage * len(self.dataset))
            self.protected_split = random.sample(range(len(self.dataset)), num_protected)
            self.protected[self.protected_split] = 1
def load_CelebA(img_size=64, batch_size=64, img_path="./dataset", intensity=1.0):
    train_loader = torch.utils.data.DataLoader(
        datasets.CelebA(
            root=img_path + "/celebA",
            # split="all",
            download=True,
            transform=transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])),
        batch_size=batch_size,
        shuffle=True)
    return {'train': train_loader, 'test': False}
def __init__(self, name, mode=None):
    super(Dataset, self).__init__()
    if name == 'mnist':
        self.dataset = datasets.MNIST('data/MNIST/', train=True, download=True,
                                      transform=transforms.Compose([
                                          transforms.Resize(64),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=(0.5,), std=(0.5,))
                                      ]))
    elif mode == 'colab':
        # ImageFolder reads from disk and takes no `download` argument
        # (the original code passed one, which raises a TypeError).
        self.dataset = datasets.ImageFolder(
            '/content/gdrive/My Drive/celeba/',
            transform=transforms.Compose([
                transforms.Resize(64),
                transforms.CenterCrop(64),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]))
    else:
        self.dataset = datasets.CelebA('data/CelebA/', download=True,
                                       transform=transforms.Compose([
                                           transforms.Resize(64),
                                           transforms.CenterCrop(64),
                                           transforms.ToTensor(),
                                           transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                                                std=(0.5, 0.5, 0.5))
                                       ]))