Example #1
def main(args):
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    #trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
    #trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    #testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
    #testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    kwargs = {'num_workers': 8, 'pin_memory': False}
    #trainloader = torch.utils.data.DataLoader(datasets.MNIST('./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),])),batch_size=args.batch_size,shuffle=True,**kwargs)
    #testloader = torch.utils.data.DataLoader(datasets.MNIST('./data',train=False,transform=transforms.Compose([transforms.ToTensor(),])),batch_size=args.batch_size,shuffle=True,**kwargs)

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    #dset = CustomImageFolder('data/CelebA', transform)
    #trainloader = torch.utils.data.DataLoader(dset,batch_size=args.batch_size,shuffle=True,num_workers=8,pin_memory=True,drop_last=True)
    #testloader = torch.utils.data.DataLoader(dset,batch_size=args.batch_size,shuffle=True,num_workers=8,pin_memory=True,drop_last=True)

    trainloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='train', download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    testloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='test', transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
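Note that `main` assumes an argparse-style `args` namespace. A minimal sketch of the flags the function actually reads (defaults hypothetical, not taken from the original repository):

import argparse

parser = argparse.ArgumentParser(description='RealNVP on CelebA')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0])
parser.add_argument('--benchmark', action='store_true')
parser.add_argument('--resume', action='store_true')
parser.add_argument('--weight_decay', type=float, default=5e-5)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--max_grad_norm', type=float, default=100.0)
parser.add_argument('--num_samples', type=int, default=64)
main(parser.parse_args())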
Example #2
def get_celeba_loaders(batch_train, batch_test):
    test_num = 128  # unused in this snippet
    images = glob.glob(os.path.join(".", "data", "celeba", "images", "*.jpg"))  # unused in this snippet

    # Both splits share the same 64x64 resize + tensor conversion.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ])

    datasets_list = {
        "train": datasets.CelebA('./data', split="train", target_type='attr',
                                 download=False, transform=transform),
        "test": datasets.CelebA('./data', split="test", target_type='attr',
                                download=False, transform=transform)
    }

    dataloaders_list = {
        "train": DataLoader(datasets_list["train"], batch_size=batch_train, shuffle=True),
        "test": DataLoader(datasets_list["test"], batch_size=batch_test, shuffle=False)
    }

    return dataloaders_list
Example #3
    def create_datasets(self):
        # Taken from pytorch MNIST demo.
        kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if self.use_cuda else {}
        transform = transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.CelebA('./celeba_data',
                            split='train',
                            download=True,
                            transform=transform),
            batch_size=self.args.batch,
            shuffle=True,
            **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.CelebA('./celeba_data',
                            split='test',
                            download=True,
                            transform=transform),
            batch_size=self.args.batch,
            shuffle=True,
            **kwargs)
        return train_loader, test_loader
Example #4
    def __init__(self, root="", transform="default", download=False):

        self.root = root
        if isinstance(transform, str) and transform == "default":
            self.trans = transforms.ToTensor()
        elif isinstance(transform, str) and transform == "vae":
            # Transforms from https://github.com/AntixK/PyTorch-VAE/blob/8700d245a9735640dda458db4cf40708caf2e77f/experiment.py#L14
            # SetRange rescales tensors from [0, 1] to [-1, 1].
            SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
            self.trans = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop(148),
                transforms.Resize(64),
                transforms.ToTensor(),
                SetRange
            ])
        else:
            self.trans = transform

        if download:
            self.download()

        self.__train_set = datasets.CelebA(root=str(root),
                                           split="train",
                                           transform=self.trans,
                                           download=False)
        self.__test_set = datasets.CelebA(root=str(root),
                                          split="test",
                                          transform=self.trans,
                                          download=False)
        self.__valid_set = datasets.CelebA(root=str(root),
                                           split="valid",
                                           transform=self.trans,
                                           download=False)
Example #5
    def setup_data_loaders(self):
        if self.dataset == 'celeba':
            transform_list = [
                transforms.CenterCrop(140),
                # PIL.Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                transforms.Resize((64, 64), PIL.Image.LANCZOS),
                transforms.ToTensor()
            ]
            #if self.input_normalize_sym:
                #D = 64*64*3
                #transform_list.append(transforms.LinearTransformation(2*torch.eye(D), -.5*torch.ones(D)))
            transform = transforms.Compose(transform_list)
            train_dataset = datasets.CelebA(self.datadir, split='train', target_type='attr', download=True, transform=transform)
            test_dataset = datasets.CelebA(self.datadir, split='test', target_type='attr', download=True, transform=transform)
            self.nlabels = 0
        elif self.dataset == 'mnist':
            train_dataset = datasets.MNIST(self.datadir, train=True, target_transform=None, download=True, transform=transforms.Compose([transforms.ToTensor()]))
            test_dataset = datasets.MNIST(self.datadir, train=False, target_transform=None, download=True, transform=transforms.Compose([transforms.ToTensor()]))
            self.nlabels = 10
        elif self.dataset == 'fashionmnist':
            train_dataset = datasets.FashionMNIST(self.datadir, train=True, target_transform=None, download=True, transform=transforms.Compose([transforms.ToTensor()]))
            test_dataset = datasets.FashionMNIST(self.datadir, train=False, target_transform=None, download=True, transform=transforms.Compose([transforms.ToTensor()]))
            self.nlabels = 10
        elif self.dataset == 'kmnist':
            train_dataset = datasets.KMNIST(self.datadir, train=True, target_transform=None, download=True, transform=transforms.Compose([transforms.ToTensor()]))
            test_dataset = datasets.KMNIST(self.datadir, train=False, target_transform=None, download=True, transform=transforms.Compose([transforms.ToTensor()]))
            self.nlabels = 10
        else:
            raise ValueError("Dataset not found: " + self.dataset)

        if self.limit_train_size is not None:
            train_dataset = torch.utils.data.random_split(
                train_dataset,
                [self.limit_train_size, len(train_dataset) - self.limit_train_size])[0]

        self.train_loader = torch.utils.data.DataLoader(
            DatasetWithIndices(train_dataset, self.input_normalize_sym),
            batch_size=self.batch_size, shuffle=True, **self.dataloader_kwargs)
        self.test_loader = torch.utils.data.DataLoader(
            DatasetWithIndices(test_dataset, self.input_normalize_sym),
            batch_size=self.batch_size, shuffle=False, **self.dataloader_kwargs)
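The commented-out `LinearTransformation` above suggests an optional remap of CelebA inputs to a symmetric range when `input_normalize_sym` is set (note the snippet also forwards that flag to `DatasetWithIndices`, which may handle it instead). A lighter-weight sketch of the same idea, assuming the intent is the usual [0, 1] -> [-1, 1] map:

            if self.input_normalize_sym:
                # Per-pixel remap to [-1, 1]; avoids building a 12288x12288 matrix.
                transform_list.append(transforms.Lambda(lambda x: 2 * x - 1))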
Example #6
    def get_celeba_loaders(self, data_augment):
        train_dataset = datasets.CelebA(self.data_root, split='train', download=True,
                                        transform=SimCLRDataTransform(data_augment))
        valid_dataset = datasets.CelebA(self.data_root, split='valid', download=True,
                                        transform=SimCLRDataTransform(data_augment))

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, 
                                    num_workers=self.num_workers, drop_last=True, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, 
                                    num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader
Example #7
def encode(save_root, model_file, data_folder, model_name='ca', dataset_name='celeba', batch_size=64, device='cuda:0', out_dim=256):
    os.makedirs(save_root, exist_ok=True)
    os.makedirs(data_folder, exist_ok=True)

    if dataset_name == 'celeba':
        train_loader = DataLoader(datasets.CelebA(data_folder, split='train', download=True, transform=transforms.ToTensor()),
                                    batch_size=batch_size, shuffle=False)
        valid_loader = DataLoader(datasets.CelebA(data_folder, split='valid', download=True, transform=transforms.ToTensor()),
                                    batch_size=batch_size, shuffle=False)
    elif dataset_name == 'stanfordCars':
        t = transforms.Compose([
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3,1,1) if x.shape[0] == 1 else x)
        ])
        train_data_dir = os.path.join(data_folder, 'cars_train/')
        train_annos = os.path.join(data_folder, 'devkit/cars_train_annos.mat')
        train_loader = DataLoader(CarsDataset(train_annos, train_data_dir, t), batch_size=batch_size, shuffle=False)
        valid_data_dir = os.path.join(data_folder, 'cars_test/')
        valid_annos = os.path.join(data_folder, 'devkit/cars_test_annos_withlabels.mat')
        valid_loader = DataLoader(CarsDataset(valid_annos, valid_data_dir, t), batch_size=batch_size, shuffle=False)
    elif dataset_name == 'compCars':
        t = transforms.Compose([
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor()
        ])
        train_loader = DataLoader(CompCars(data_folder, True, t), batch_size=batch_size, shuffle=False)
        valid_loader = DataLoader(CompCars(data_folder, False, t), batch_size=batch_size, shuffle=False)
    else:
        # Without this guard an unknown dataset_name would leave the loaders undefined.
        raise ValueError(f'Unknown dataset_name: {dataset_name}')

    model = ResNetSimCLR('resnet50', out_dim)
    model.load_state_dict(torch.load(model_file, map_location=device))
    model = model.to(device)
    model.eval()

    print('Starting on training data')
    train_encodings = []
    for x, _ in train_loader:
        x = x.to(device)
        h, _ = model(x)
        train_encodings.append(h.cpu().detach())
    torch.save(torch.cat(train_encodings, dim=0), os.path.join(save_root, f'{dataset_name}-{model_name}model-train_encodings.pt'))

    print('Starting on validation data')
    valid_encodings = []
    for x, _ in valid_loader:
        x = x.to(device)
        h, _ = model(x)
        if len(h.shape) == 1:
            h = h.unsqueeze(0)
        valid_encodings.append(h.cpu().detach())
    torch.save(torch.cat(valid_encodings, dim=0), os.path.join(save_root, f'{dataset_name}-{model_name}model-valid_encodings.pt'))
Example #8
def get_data_CelebA(transform, batch_size, download=True, root="/data"):
    print("Loading trainset...")
    trainset = Datasets.CelebA(root=root, split="train", transform=transform, download=download)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

    print("Loading testset...")
    testset = Datasets.CelebA(root=root, split='test', download=download, transform=transform)

    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8)
    print("Done!")

    return trainloader, testloader
Example #9
    def get_loader(self, sz, bs, dt=None, num_workers=10):
        if dt is None:
            dt = get_default_dt(self.basic_types, sz)

        if self.basic_types is None:
            train_dataset = datasets.ImageFolder(self.train_dir, dt)
        elif self.basic_types == 'MNIST':
            train_dataset = datasets.MNIST(self.train_dir,
                                           train=True,
                                           download=True,
                                           transform=dt)
        elif self.basic_types == 'CIFAR10':
            train_dataset = datasets.CIFAR10(self.train_dir,
                                             train=True,
                                             download=True,
                                             transform=dt)
        elif self.basic_types == 'CelebA':
            # CelebA has no train= flag; the default split is 'train'.
            train_dataset = datasets.CelebA(self.train_dir,
                                            download=True,
                                            transform=dt)
        else:
            # Without this guard an unknown basic_types would leave train_dataset undefined.
            raise ValueError(f'Unsupported basic_types: {self.basic_types}')

        train_loader = DataLoader(train_dataset,
                                  batch_size=bs,
                                  shuffle=self.shuffle,
                                  num_workers=num_workers)

        return train_loader
Example #10
def make_dataset(op):
    print('Downloading data...')
    dataset_name = op.data.name

    if dataset_name == 'celeba':
        return dset.CelebA(root='celeba',
                           download=True,
                           transform=make_transforms(op))

    elif dataset_name == 'fashion_mnist':
        return dset.FashionMNIST(root='fashion_mnist',
                                 download=True,
                                 transform=make_transforms(op))

    elif dataset_name == 'pokemon':
        return PokeSprites(op)

    elif dataset_name in Pix2Pix_Datasets.keys():
        return Pix2Pix(root=dataset_name,
                       dataset_name=dataset_name,
                       download=True,
                       transform=make_transforms(op))

    else:
        raise ValueError(f'{dataset_name} not supported!')
Example #11
def get_celeba_loaders(path, batch_size, image_size):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    train_set = datasets.CelebA(path,
                                split='train',
                                download=True,
                                transform=transform)
    loader = DataLoader(train_set,
                        shuffle=True,
                        batch_size=batch_size,
                        num_workers=4)
    loader = iter(loader)

    while True:
        try:
            yield next(loader)

        except StopIteration:
            loader = DataLoader(train_set,
                                shuffle=True,
                                batch_size=batch_size,
                                num_workers=4)
            loader = iter(loader)
            yield next(loader)
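Because `get_celeba_loaders` here is an endless generator rather than a plain `DataLoader`, batches are drawn with `next()`; a hypothetical consumer (step count and variable names assumed):

loader = get_celeba_loaders('./data', batch_size=16, image_size=64)
for step in range(100000):
    images, attrs = next(loader)  # never raises StopIteration; the DataLoader is rebuilt between epochs
    # ... one training step on images ...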
Example #12
    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 size=[32, 32],
                 num_points=200,
                 eval_mode='none',
                 num_points_eval=200,
                 download=False):
        self.name = 'CELEBA'
        split = 'train' if train else 'test'
        self.size = size  #Original size: (218, 178)
        self.num_points = num_points
        self.eval_mode = eval_mode
        self.num_points_eval = num_points_eval

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

        self.dataset = datasets.CelebA(root=root,
                                       split=split,
                                       download=download,
                                       transform=transform)

        self.coordinates = torch.from_numpy(
            np.array([[int(i / size[1]) / size[0], (i % size[1]) / size[1]]
                      for i in range(size[0] * size[1])])).float()
Example #13
    def __init__(self, root, img_size=64, data='celeb-a'):
        super().__init__()

        self.root = root
        self.img_size = img_size

        ## Train Set
        self.tr_data = data

        ## Transform functions
        self.transform = Compose([
            Resize(self.img_size, self.img_size),
            Normalize(mean=[0.5], std=[0.5])
        ])

        if self.tr_data == 'celeb-a':
            self.data = datasets.CelebA(self.root,
                                        split='train',
                                        download=True)
        elif self.tr_data == 'mnist':
            self.data = datasets.MNIST(self.root, train=True, download=True)
        else:
            self.data_list = []
            for image in os.listdir(self.root):
                self.data_list.append(self.root + '/' + image)
Example #14
def get_dataloader(dataset, args):
    if dataset == 'celebA':
        m = (0.5, 0.5, 0.5)
        s = (0.5, 0.5, 0.5)
    else:
        m = (0.5, )
        s = (0.5, )
    transform_img = transforms.Compose([
        transforms.Resize(size=args.img_size, interpolation=0),  # 0 == PIL.Image.NEAREST
        transforms.ToTensor(),
        transforms.Normalize(mean=m, std=s),
    ])

    if dataset == 'mnist':
        dataset = datasets.MNIST(root='./DATA',
                                 train=True,
                                 transform=transform_img,
                                 download=True)
    elif dataset == 'fashionmnist':
        dataset = datasets.FashionMNIST(root='./DATA',
                                        train=True,
                                        download=True,
                                        transform=transform_img)
    elif dataset == 'celebA':
        dataset = datasets.CelebA(root='./DATA',
                                  split='train',
                                  download=True,
                                  transform=transform_img)
    else:
        # Without this guard an unknown name would pass the raw string to DataLoader.
        raise ValueError(f'Unsupported dataset: {dataset}')

    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch,
                        shuffle=True,
                        num_workers=4)
    return loader
Example #15
    def __init__(self, root, split, normal_class, normal=True, transform=None, abnormal_class=None,
                 extended_attribute_list=False):
        assert normal_class == 0

        self.data = datasets.CelebA(root, split)
        self.transform = transform

        if extended_attribute_list:
            self.attributes = ["Bags_Under_Eyes", "Bald", "Bangs", "Eyeglasses", "Goatee",
                                                  "Heavy_Makeup", "Mustache", "Sideburns", "Wearing_Hat"]
        else:
            self.attributes = ["Bald", "Mustache", "Bangs", "Eyeglasses", "Wearing_Hat"]

        if normal:
            byte_index = torch.ones(len(self.data), dtype=torch.bool)
            for attr_name in self.attributes:
                byte_index = byte_index.logical_and(self.data.attr[:, self.data.attr_names.index(attr_name)] == 0)
            self.active_indexes = torch.nonzero(byte_index, as_tuple=False).numpy().flatten()
        else:
            assert abnormal_class in self.attributes
            # keep images where this attribute is present
            byte_index = self.data.attr[:, self.data.attr_names.index(abnormal_class)] == 1
            # and where all other attributes are absent
            for attr_name in self.attributes:
                if attr_name != abnormal_class:
                    byte_index = byte_index.logical_and(self.data.attr[:, self.data.attr_names.index(attr_name)] == 0)

            self.active_indexes = torch.nonzero(byte_index, as_tuple=False).numpy().flatten()

        if split == 'train':
            self._min_dataset_size = 10000  # as required in _NormalAnomalyBase
        else:
            self._min_dataset_size = 0
Example #16
def load_dataset(dataset_name, root='data'):
    img_preprocess = preprocess(dataset_name)

    if dataset_name == "mnist":
        train_dataset = datasets.MNIST(root,
                                       train=True,
                                       download=True,
                                       transform=img_preprocess)
        val_dataset = datasets.MNIST(root,
                                     train=False,
                                     download=True,
                                     transform=img_preprocess)
        label_names = list(map(str, range(10)))

    elif dataset_name == "celeba":
        # Note: download=True fails when daily quota on this dataset has been reached
        train_dataset = datasets.CelebA(root,
                                        split='train',
                                        target_type='attr',
                                        transform=img_preprocess,
                                        download=True)
        val_dataset = datasets.CelebA(root,
                                      split='valid',
                                      target_type='attr',
                                      transform=img_preprocess,
                                      download=True)
        label_names = []

    elif dataset_name == "cifar10":
        train_dataset = datasets.CIFAR10(root,
                                         train=True,
                                         download=True,
                                         transform=img_preprocess)
        val_dataset = datasets.CIFAR10(root,
                                       train=False,
                                       download=True,
                                       transform=img_preprocess)
        label_names = [
            'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
            'horse', 'ship', 'truck'
        ]

    else:
        print("Invalid dataset name, exiting...")
        return

    return train_dataset, val_dataset, label_names
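As the comment in the CelebA branch notes, `download=True` can fail once the Google Drive daily quota for the dataset is exhausted. One possible guard (a sketch under that assumption, not code from the original) is to fall back to a local copy:

    try:
        train_dataset = datasets.CelebA(root, split='train', target_type='attr',
                                        transform=img_preprocess, download=True)
    except RuntimeError:
        # Quota exceeded or files blocked; reuse data already extracted under root/celeba.
        train_dataset = datasets.CelebA(root, split='train', target_type='attr',
                                        transform=img_preprocess, download=False)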
Example #17
    def prepare_data(self):
        path = os.getcwd()
        dataset = dset.CelebA(path, split='train',
                              transform=transforms.Compose([
                                  transforms.Resize(128),
                                  transforms.CenterCrop(128),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                              ]),
                              target_transform=None, target_type='attr',
                              download=False)
        self.dataset = dataset
Example #18
    def prepare_data(self, *args, **kwargs):
        self.train_data = datasets.CelebA(
            self.data_root,
            target_type='attr',
            split='train',
            download=self.download,
            target_transform=self.__target_transform,
            transform=self.__transform)

        self.val_data = datasets.CelebA(
            self.data_root,
            target_type='attr',
            split='valid',
            download=self.download,
            target_transform=self.__target_transform,
            transform=self.__transform)

        self.__set_group_counts()
Example #19
def load_celeba(batch_size=16):
    transform = transforms.Compose([
        transforms.CenterCrop(89),
        transforms.Resize(32),
        transforms.ToTensor()
    ])
    dataset = ds.CelebA('./data', transform=transform, split='train', download=True)
    data_loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    return data_loader
Example #20
    def load_data(self, dataset, batch_size) -> None:
        download = not os.path.exists('../data/' + dataset)
        if dataset == 'MNIST':
            self.train_dataset = datasets.MNIST(root='../data', train=True,
                                                transform=ToTensor(), download=download)
        elif dataset == 'CelebA':
            self.train_dataset = datasets.CelebA(root='../data', split='train',
                                                 transform=ToTensor(), download=download)

        self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
Example #21
    def raw_dataset(self, data_dir: str, download: bool, split: str,
                    transform):
        assert split in ['train', 'val', 'test']
        if split == 'val':
            split = 'valid'  # torchvision names the validation split 'valid'
        return datasets.CelebA(data_dir,
                               download=download,
                               split=split,
                               transform=transform,
                               target_type=self.target_type)
Example #22
    def __init__(self, data_root, debug, cuda_enabled, quiet, checkpoint):
        # Hyperparams
        self.batch_size = 64
        self.epochs = 100
        self.z_dim = 100
        self.lr = 0.0002

        self.image_size = 64

        self.cuda_enabled = cuda_enabled
        self.quiet = quiet
        self.checkpoint_dir = 'checkpoints'
        if not os.path.isdir(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)

        # Data augmentations
        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        # Instantiate data loaders
        self.train_dataset = datasets.CelebA(root=data_root, split='all', download=True, transform=transform)
        self.train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.num_channels = 3

        # Setup discriminator
        self.d = D(self.num_channels)
        if self.cuda_enabled:
            self.d.cuda()

        # Setup generator
        self.g = G(self.z_dim, self.num_channels)
        if self.cuda_enabled:
            self.g.cuda()

        # Load checkpoint
        self.cur_epoch = 0
        if checkpoint is not None:
            state_dict = torch.load(checkpoint)
            self.g.load_state_dict(state_dict['g'])
            self.d.load_state_dict(state_dict['d'])
            self.cur_epoch = int(os.path.basename(checkpoint).split('_')[0])

        # Setup loss and optimizers
        self.loss = nn.BCELoss()
        self.g_opt = optim.Adam(self.g.parameters(), lr=self.lr, betas=(0.5, 0.999))
        self.d_opt = optim.Adam(self.d.parameters(), lr=self.lr, betas=(0.5, 0.999))

        # Setup options
        if debug:
            self.epochs = 1
Example #23
    def __init__(self):
        super(CelebA, self).__init__()
        self.binarized = config.get("data", "binarized", default=False)
        self.gauss_noise = config.get("data", "gauss_noise", default=False)
        self.noise_std = config.get("data", "noise_std", default=0.01)
        self.flattened = config.get("data", "flattened", default=False)

        if self.binarized:
            assert not self.gauss_noise

        self.data_path = os.path.join("workspace", "datasets", "celeba")

        # Get datasets
        im_transformer = transforms.Compose([
            transforms.CenterCrop(140),
            transforms.Resize(32),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        self.train = datasets.CelebA(self.data_path,
                                     split="train",
                                     target_type=[],
                                     transform=im_transformer,
                                     download=True)
        self.val = datasets.CelebA(self.data_path,
                                   split="valid",
                                   target_type=[],
                                   transform=im_transformer,
                                   download=True)
        self.test = datasets.CelebA(self.data_path,
                                    split="test",
                                    target_type=[],
                                    transform=im_transformer,
                                    download=True)

        self.train = UnlabeledDataset(self.train)
        self.test = UnlabeledDataset(self.test)
        self.val = UnlabeledDataset(self.val)
Example #24
def get_data_loader():
    dataset = dset.CelebA(root=args.celeba_loc,
                          split="all",
                          download=False,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.CenterCrop(image_size),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=workers)
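`get_data_loader` reads `args.celeba_loc`, `image_size`, `batch_size`, and `workers` from module scope; a hypothetical configuration block showing what it expects (all values assumed, not from the original module):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--celeba_loc', default='./data')  # hypothetical default
args = parser.parse_args()

image_size = 64   # values assumed; the original module defines these elsewhere
batch_size = 128
workers = 2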
Example #25
def main():
    tfms = transforms.Compose(
        [transforms.Resize(size=224),
         transforms.ToTensor()])
    train_dataset = datasets.CelebA(root='data/',
                                    split='train',
                                    target_type='attr',
                                    transform=tfms,
                                    download=True)
    valid_dataset = datasets.CelebA(root='data/',
                                    split='valid',
                                    target_type='attr',
                                    transform=tfms,
                                    download=True)

    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=4, shuffle=False)

    # visualize a batch
    if DEBUG:
        for batch_idx, (images, labels) in enumerate(train_dataloader):
            if batch_idx > 0:
                break
            img = make_grid(images)
            img = img.numpy().transpose(1, 2, 0)
            plt.imshow(img)
            print(labels)
            plt.show()

    # Training loop
    model = Resnet18(num_classes=2)
    optimizer = torch.optim.Adam(model.parameters())

    import ipdb; ipdb.set_trace()  # noqa  # yapf: disable
Example #26
    def setup_datasets(self):
        dsb_train = datasets.CelebA(self.dspath,
                                    download=False,
                                    split='train',
                                    transform=transforms.Compose([
                                        transforms.Resize(self.shape),
                                        transforms.CenterCrop(self.shape),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                             std=[0.5, 0.5, 0.5])
                                    ]))

        dsb_val = datasets.CelebA(self.dspath,
                                  download=False,
                                  split='valid',  # without this, the default split ('train') would be loaded again
                                  transform=transforms.Compose([
                                      transforms.Resize(self.shape),
                                      transforms.CenterCrop(self.shape),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                           std=[0.5, 0.5, 0.5])
                                  ]))

        return dsb_train, dsb_val
Example #27
def celebA_feature_set(split='train',
                       shuffle=None,
                       batch_size=128,
                       attributes=None,
                       augm_type='default',
                       out_size=224,
                       config_dict=None):
    if split == 'test' and augm_type != 'none':
        print('WARNING: Test set in use with data augmentation')

    if shuffle is None:
        shuffle = (split == 'train')

    if attributes is None:
        attributes = celebA_attributes
        target_transform = None
    else:
        target_transform = get_celebA_target_transform(attributes)

    augm_config = {}
    augm = get_celebA_augmentation(augm_type,
                                   out_size=out_size,
                                   config_dict=augm_config)

    path = get_celebA_path()
    dataset = datasets.CelebA(path,
                              split=split,
                              target_type='attr',
                              transform=augm,
                              target_transform=target_transform,
                              download=False)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=8)

    if config_dict is not None:
        config_dict['Dataset'] = 'CelebA'
        config_dict['Batch size'] = batch_size
        config_dict['Augmentation'] = augm_config
        config_dict['Attributes'] = attributes

    return loader
Example #28
    def __init__(self,
                 split,
                 dir_name='data/',
                 subset_percentage=1,
                 protected_percentage=1,
                 balance_protected=True):
        self.transform_image = transform_image
        self.gender_index = 20
        self.dataset = datasets.CelebA(dir_name,
                                       split=split,
                                       transform=transform_image,
                                       target_transform=None,
                                       download=True)

        if subset_percentage < 1:
            self.dataset = Subset(
                self.dataset,
                range(ceil(subset_percentage * len(self.dataset))))

        # Handle protected split (only relevant for train).
        self.protected = np.zeros(len(self.dataset))
        if split == 'train':
            if balance_protected:
                # Get gender information.
                genders = np.array([
                    self.dataset[i][1][self.gender_index]
                    for i in range(len(self.dataset))
                ])
                male_idxs = np.where(genders == 1)[0].tolist()
                female_idxs = np.where(genders == 0)[0].tolist()
                # Set number of protected data points.
                max_percentage = min(
                    len(male_idxs), len(female_idxs)) * 2.0 / len(self.dataset)
                if protected_percentage > max_percentage:
                    protected_percentage = max_percentage
                num_protected = ceil(protected_percentage * len(self.dataset))
                if num_protected % 2 == 1:
                    num_protected -= 1  # keep the count even so genders split equally
                # Create protected split.
                self.protected_split = random.sample(male_idxs,
                                                     int(num_protected / 2))
                self.protected_split.extend(
                    random.sample(female_idxs, int(num_protected / 2)))
                self.protected[self.protected_split] = 1
            else:
                num_protected = ceil(protected_percentage * len(self.dataset))
                self.protected_split = random.sample(range(len(self.dataset)),
                                                     num_protected)
                self.protected[self.protected_split] = 1
Example #29
def load_CelebA(img_size=64,
                batch_size=64,
                img_path="./dataset",
                intensity=1.0):
    train_loader = torch.utils.data.DataLoader(
        datasets.CelebA(
            root=img_path + "/celebA",
            # split="all",
            download=True,
            transform=transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])),
        batch_size=batch_size,
        shuffle=True)

    return {'train': train_loader, 'test': False}
Example #30
    def __init__(self, name, mode=None):
        super(Dataset, self).__init__()

        if name == 'mnist':
            self.dataset = datasets.MNIST('data/MNIST/',
                                          train=True,
                                          download=True,
                                          transform=transforms.Compose([
                                              transforms.Resize(64),
                                              transforms.ToTensor(),
                                              transforms.Normalize(
                                                  mean=(0.5, ), std=(0.5, ))
                                          ]))
        elif mode == 'colab':
            # ImageFolder takes no download argument; the images must already
            # exist under the mounted Drive path.
            self.dataset = datasets.ImageFolder(
                '/content/gdrive/My Drive/celeba/',
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.CenterCrop(64),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                         std=(0.5, 0.5, 0.5))
                ]))

        else:
            self.dataset = datasets.CelebA('data/CelebA/',
                                           download=True,
                                           transform=transforms.Compose([
                                               transforms.Resize(64),
                                               transforms.CenterCrop(64),
                                               transforms.ToTensor(),
                                               transforms.Normalize(
                                                   mean=(0.5, 0.5, 0.5),
                                                   std=(0.5, 0.5, 0.5))
                                           ]))