Example #1
            output, target, reduction='sum').item()
        correct = correct + pred.eq(target.data.view_as(pred)).cpu().sum().item()

    test_loss /= len(testloader.dataset)
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(testloader.dataset),
            100. * correct / len(testloader.dataset)))


if __name__ == "__main__":

    svhn_dataset_train = SVHN(root='/data02/Atin/DeployedProjects/SVHN',
                              split='train',
                              transform=Compose([
                                  Lambda(addGaussian),
                                  ToTensor(),
                                  Normalize(mean, std)
                              ]))
    svhn_dataset_test = SVHN(root='/data02/Atin/DeployedProjects/SVHN',
                             split='test',
                             download=True,
                             transform=Compose(
                                 [ToTensor(), Normalize(mean, std)]))

    train_dataloader = DataLoader(svhn_dataset_train,
                                  batch_size=64,
                                  num_workers=10,
                                  shuffle=True)
    test_dataloader = DataLoader(svhn_dataset_test,
                                 batch_size=64,
                                 shuffle=False)
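
The snippet above references addGaussian, mean, and std without showing their definitions; a minimal sketch of plausible stand-ins, assuming addGaussian operates on the PIL image before ToTensor() (the values and implementation below are illustrative assumptions, not the original code):

import numpy as np
from PIL import Image

# Commonly quoted per-channel statistics for SVHN; the snippet's actual values are not shown
mean = (0.4377, 0.4438, 0.4728)
std = (0.1980, 0.2010, 0.1970)

def addGaussian(img, sigma=10.0):
    # Add pixel-wise Gaussian noise to a PIL image and clip back to the valid byte range
    arr = np.asarray(img, dtype=np.float32)
    noisy = arr + np.random.normal(0.0, sigma, size=arr.shape)
    return Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))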
Example #2
    def __init__(self, train=True, augment=False):

        transform = transforms.ToTensor()

        self.data = SVHN("/tmp/svhn",
                         split='train' if train else 'test',
                         transform=transform,
                         download=True)
        self.one_hot_map = np.eye(10)
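
Only the constructor is shown above; the one_hot_map lookup is presumably applied to the integer label in __getitem__. A minimal sketch of such a method, assuming the standard (image, label) tuples returned by torchvision's SVHN (an illustration, not the original class):

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # self.data is the wrapped SVHN dataset: returns (tensor image, integer label)
        image, label = self.data[index]
        # look up the matching row of the identity matrix to get a one-hot target
        return image, self.one_hot_map[label]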
def load_data(opt):
    """ Load Data

    Args:
        opt ([type]): Argument Parser

    Raises:
        IOError: Cannot Load Dataset

    Returns:
        [type]: dataloader
    """

    ##
    # LOAD DATA SET
    if opt.dataroot == '':
        opt.dataroot = './data/{}'.format(opt.dataset)

    if opt.dataset in ['cifar10']:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        classes = {
            'plane': 0,
            'car': 1,
            'bird': 2,
            'cat': 3,
            'deer': 4,
            'dog': 5,
            'frog': 6,
            'horse': 7,
            'ship': 8,
            'truck': 9
        }

        dataset = {}
        dataset['train'] = CIFAR10(root='./data',
                                   train=True,
                                   download=True,
                                   transform=transform)
        dataset['test'] = CIFAR10(root='./data',
                                  train=False,
                                  download=True,
                                  transform=transform)

        if opt.task == 'anomaly_detect':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_cifar_anomaly_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                abn_cls_idx=classes[opt.anomaly_class]
            )
        elif opt.task == 'random_walk':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_sub_cifar10_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
            )
        elif opt.task == 'llk_trend':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_sub_cifar10_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                abn_cls_idx=[0]
            )
        elif opt.task == 'rw_llk':  ## for simplicity, handle random walk and llk together
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_cifar_rwllk_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                abn_cls_idx=classes[opt.anomaly_class]
            )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batch_size,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader
    elif opt.dataset in ['mnist']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        ## It seems that, for the transform in our GAN, we can use an identical transform for both CIFAR-10 and MNIST.
        ## The following did not work well at all:
        # transform = transforms.Compose(
        #     [
        #         transforms.Scale(opt.isize),
        #         transforms.ToTensor(),
        #         transforms.Normalize((0.1307,), (0.3081,))
        #     ]
        # )
        ## The second one works well for MNIST:
        transform = transforms.Compose([
            transforms.Resize((opt.isize, opt.isize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        dataset = {}
        dataset['train'] = MNIST(root='./data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        dataset['test'] = MNIST(root='./data',
                                train=False,
                                download=True,
                                transform=transform)

        if opt.task == 'anomaly_detect':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_mnist_anomaly_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                abn_cls_idx=opt.anomaly_class
            )
        elif opt.task == 'random_walk':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_sub_mnist_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
            )

        elif opt.task == 'llk_trend':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_sub_mnist_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                abn_cls_idx = [0]
            )
        elif opt.task == 'rw_llk':  ## for simplicity, handle random walk and llk together
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_mnist_rwllk_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                abn_cls_idx=opt.anomaly_class
            )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batch_size,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }

        return dataloader

    elif opt.dataset in ['svhn']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        ## It seems that, for the transform in our GAN, we can use an identical transform for both CIFAR-10 and MNIST.
        ## The following did not work well at all:
        # transform = transforms.Compose(
        #     [
        #         transforms.Scale(opt.isize),
        #         transforms.ToTensor(),
        #         transforms.Normalize((0.1307,), (0.3081,))
        #     ]
        # )
        ## The second one works well for MNIST:
        transform = transforms.Compose([
            transforms.Resize((opt.isize, opt.isize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        dataset = {}
        dataset['train'] = SVHN(root='./data',
                                split='train',
                                download=True,
                                transform=transform)
        dataset['test'] = SVHN(root='./data',
                               split='test',
                               download=True,
                               transform=transform)

        if opt.task == 'anomaly_detect':
            dataset['train'].data, dataset['train'].labels, \
            dataset['test'].data, dataset['test'].labels = get_cifar_anomaly_dataset(  ## not sure if we need to write a get_svhn_anomaly_dataset yet; note SVHN stores its arrays in .data/.labels
                trn_img=dataset['train'].data,
                trn_lbl=dataset['train'].labels,
                tst_img=dataset['test'].data,
                tst_lbl=dataset['test'].labels,
                abn_cls_idx=opt.anomaly_class
            )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batch_size,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader

    elif opt.dataset in ['mnist2']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        dataset = {}
        dataset['train'] = MNIST(root='./data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        dataset['test'] = MNIST(root='./data',
                                train=False,
                                download=True,
                                transform=transform)

        if opt.task == 'anomaly_detect':
            dataset['train'].train_data, dataset['train'].train_labels, \
            dataset['test'].test_data, dataset['test'].test_labels = get_mnist2_anomaly_dataset(
                trn_img=dataset['train'].train_data,
                trn_lbl=dataset['train'].train_labels,
                tst_img=dataset['test'].test_data,
                tst_lbl=dataset['test'].test_labels,
                nrm_cls_idx=opt.anomaly_class,
                proportion=opt.proportion
            )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batch_size,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader

    # elif opt.dataset in ['celebA']:
    #     ##not for abnormal detection but not for classification either
    #     splits = ['train', 'test']
    #     drop_last_batch = {'train': True, 'test': False}
    #     shuffle = {'train': True, 'test': True}
    #     transform = transforms.Compose([transforms.Scale(opt.isize),
    #                                     transforms.CenterCrop(opt.isize),
    #                                     transforms.ToTensor(),
    #                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
    #     print(os.path.abspath('./data/celebA'))
    #     dataset = datasets.ImageFolder(os.path.abspath('./data/celebA'), transform)
    #     dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True)
    #     return dataloader
    elif opt.dataset in ['celebA']:
        ## not for anomaly detection, and not for classification either
        # import helper
        # helper.download_extract('celeba', opt.dataroot)
        # splits = ['train', 'test']
        # drop_last_batch = {'train': True, 'test': False}
        # shuffle = {'train': True, 'test': True}
        # transform = transforms.Compose([transforms.Scale(opt.isize),
        #                                 transforms.CenterCrop(opt.isize),
        #                                 transforms.ToTensor(),
        #                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
        #
        # # transform = transforms.Compose([
        # #     transforms.CenterCrop(160),
        # #     transforms.Scale(opt.isize),
        # #     transforms.ToTensor(),)
        #
        # dataset = ImageFolder(root=image_root, transform=transforms.Compose([
        #     transforms.CenterCrop(160),
        #     transforms.Scale(scale_size),
        #     transforms.ToTensor(),
        #     #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # ]))
        #
        # dataset = {x: ImageFolder(os.path.join(opt.dataroot, x), transform) for x in splits}
        # dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
        #                                              batch_size=opt.batch_size,
        #                                              shuffle=shuffle[x],
        #                                              num_workers=int(opt.workers),
        #                                              drop_last=drop_last_batch[x]) for x in splits}

        dataloader = get_loader('./data/celebA', 'train', opt.batch_size,
                                opt.isize)
        return dataloader
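
A minimal usage sketch for the dictionary-returning branches of load_data (the Namespace below mirrors only the attributes the function reads; the concrete values are assumptions for illustration):

from argparse import Namespace

opt = Namespace(dataset='cifar10', dataroot='', isize=32, batch_size=64,
                workers=4, task='anomaly_detect', anomaly_class='car')

dataloader = load_data(opt)
for images, labels in dataloader['train']:
    break  # consume one training batch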
    def forward(self, x_batch, y_batch):
        loss_value = self.loss_function(x_batch, y_batch)
        reg_value = torch.norm(self.weight_vector)
        return loss_value + self.alpha * reg_value


import torch
from torch import nn
import numpy as np
from sklearn.decomposition import NMF
from torchvision.transforms import transforms

from torchvision.datasets import SVHN
from torch.utils.data import DataLoader

SVHN(root='./data/svhn', split='train', download=True)  # pre-fetch SVHN; the root path here is an assumption


class LRFLoss(nn.Module):
    def __init__(self,
                 loss_function: nn.Module,
                 net: nn.Module,
                 verbose=False):
        super().__init__()
        self.loss_function = loss_function
        self.net = net
        self.k = 1
        self.theta_star = list()
        self.verbose = verbose
        p_list = list(self.net.parameters())
        for p in p_list:
d_list = [
    'alice', 'alice-z', 'alice-x', 'vegan', 'vegan-wgan-gp', 'vegan-kl',
    'vegan-ikl', 'vegan-jsd', 'vegan-mmd'
]
if method in d_list:
    distance_x = 'l2'  # l1, l2
lambda_ = 1.  # Balance reconstruction and regularization in vegan

dim = 64  # Model dimensionality
output_dim = 3072  # Number of pixels in svhn (3*32*32)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_dset = SVHN('./data/svhn', 'train', download=True, transform=transform)
test_dset = SVHN('./data/svhn', 'test', download=True, transform=transform)
train_loader = DataLoader(train_dset,
                          num_workers=4,
                          pin_memory=True,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True)
test_loader = DataLoader(test_dset,
                         num_workers=4,
                         pin_memory=True,
                         batch_size=1,
                         shuffle=False)

num_iter = 200000
dim_latent = 128
Example #6
def get_dataset(cfg, fine_tune):
    data_transform = Compose([Resize((32, 32)), ToTensor()])
    mnist_transform = Compose(
        [Resize((32, 32)),
         ToTensor(),
         Lambda(lambda x: swapaxes(x, 1, -1))])
    vade_transform = Compose([ToTensor()])

    if cfg.DATA.DATASET == 'mnist':
        transform = vade_transform if 'vade' in cfg.DIRS.CHKP_PREFIX \
            else mnist_transform

        training_set = MNIST(download=True,
                             root=cfg.DIRS.DATA,
                             transform=transform,
                             train=True)
        val_set = MNIST(download=False,
                        root=cfg.DIRS.DATA,
                        transform=transform,
                        train=False)
        plot_set = copy.deepcopy(val_set)

    elif cfg.DATA.DATASET == 'svhn':
        training_set = SVHN(download=True,
                            root=create_dir(cfg.DIRS.DATA, 'SVHN'),
                            transform=data_transform,
                            split='train')
        val_set = SVHN(download=True,
                       root=create_dir(cfg.DIRS.DATA, 'SVHN'),
                       transform=data_transform,
                       split='test')
        plot_set = copy.deepcopy(val_set)

    elif cfg.DATA.DATASET == 'cifar':
        training_set = CIFAR10(download=True,
                               root=create_dir(cfg.DIRS.DATA, 'CIFAR'),
                               transform=data_transform,
                               train=True)
        val_set = CIFAR10(download=True,
                          root=create_dir(cfg.DIRS.DATA, 'CIFAR'),
                          transform=data_transform,
                          train=False)
        plot_set = copy.deepcopy(val_set)

    elif cfg.DATA.DATASET == 'lines':
        vae = 'vae' in cfg.DIRS.CHKP_PREFIX
        training_set = LinesDataset(args=cfg,
                                    multiplier=1000,
                                    dataset_type='train',
                                    vae=vae)
        val_set = LinesDataset(args=cfg,
                               multiplier=10,
                               dataset_type='test',
                               vae=vae)
        plot_set = LinesDataset(args=cfg,
                                multiplier=1,
                                dataset_type='plot',
                                vae=vae)

    if 'idec' in cfg.DIRS.CHKP_PREFIX and fine_tune:
        training_set = IdecDataset(training_set)
        val_set = IdecDataset(val_set)
        plot_set = IdecDataset(plot_set)

    return training_set, val_set, plot_set
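
get_dataset expects a nested config exposing DATA.DATASET, DIRS.DATA and DIRS.CHKP_PREFIX; a minimal calling sketch (the config object below is an assumption built only from the attributes the function actually reads):

from types import SimpleNamespace

cfg = SimpleNamespace(
    DATA=SimpleNamespace(DATASET='svhn'),
    DIRS=SimpleNamespace(DATA='./data', CHKP_PREFIX='vade_svhn'),
)

training_set, val_set, plot_set = get_dataset(cfg, fine_tune=False)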
Example #7
def load_data(dataset, seed, args):
    if seed:
        # in the normal setting we do not set a random seed here;
        # runs in the same group should share the same shuffling result
        torch.manual_seed(seed)
        random.seed(seed)
    if dataset == "MNIST":
        training_set = datasets.MNIST('./mnist_data',
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_loader = torch.utils.data.DataLoader(training_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        test_loader = None
    elif dataset == "Cifar10":
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        # data prep for training set
        # note that the key to reaching the convergence performance reported in this paper (https://arxiv.org/abs/1512.03385)
        # is to apply data augmentation
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(Variable(
                x.unsqueeze(0), requires_grad=False), (4, 4, 4, 4),
                                              mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # data prep for test set
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        # load training and test set here:
        training_set = datasets.CIFAR10(root='./cifar10_data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)
        train_loader = torch.utils.data.DataLoader(training_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        testset = datasets.CIFAR10(root='./cifar10_data',
                                   train=False,
                                   download=True,
                                   transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=args.test_batch_size, shuffle=False)
    elif dataset == 'SVHN':
        training_set = SVHN('./svhn_data',
                            split='train',
                            download=True,
                            transform=transforms.Compose([
                                transforms.RandomCrop(32, padding=4),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                     (0.2023, 0.1994, 0.2010)),
                            ]))
        train_loader = torch.utils.data.DataLoader(training_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        testset = SVHN(root='./svhn_data',
                       split='test',
                       download=True,
                       transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=args.test_batch_size, shuffle=False)
    elif args.dataset == "Cifar100":
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(Variable(
                x.unsqueeze(0), requires_grad=False), (4, 4, 4, 4),
                                              mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # data prep for test set
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        # load training and test set here:
        training_set = datasets.CIFAR100(root='./cifar100_data',
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        train_loader = torch.utils.data.DataLoader(training_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        testset = datasets.CIFAR100(root='./cifar100_data',
                                    train=False,
                                    download=True,
                                    transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=args.test_batch_size, shuffle=False)
    return train_loader, training_set, test_loader
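
The reflect-padding augmentation above goes through the legacy Variable API and a Lambda/ToPILImage round trip; on current torchvision the same pad-then-crop effect can be written directly (an equivalent sketch, not the original code):

from torchvision import transforms

normalize = transforms.Normalize(
    mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
    std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
transform_train = transforms.Compose([
    # reflect-pad 4 pixels on every side, then take a random 32x32 crop
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])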
def load_data(dataset, path, batch_size=64, normalize=False):
    if normalize:
        # Wasserstein BiGAN is trained on normalized data.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        # BiGAN is trained on unnormalized data (see Dumoulin et al. ICLR 16).
        transform = transforms.ToTensor()

    if dataset == 'svhn':
        train_set = SVHN(path,
                         split='extra',
                         transform=transform,
                         download=True)
        val_set = SVHN(path, split='test', transform=transform, download=True)

    elif dataset == 'stl10':
        train_set = STL10(path,
                          split='train',
                          transform=transform,
                          download=True)
        val_set = STL10(path, split='test', transform=transform, download=True)

    elif dataset == 'cifar10':
        train_set = CIFAR10(path,
                            train=True,
                            transform=transform,
                            download=True)
        val_set = CIFAR10(path,
                          train=False,
                          transform=transform,
                          download=True)

    elif dataset == 'cifar100':
        train_set = CIFAR100(path,
                             train=True,
                             transform=transform,
                             download=True)
        val_set = CIFAR100(path,
                           train=False,
                           transform=transform,
                           download=True)

    elif dataset == 'VOC07':
        train_set = VOCSegmentation(path,
                                    image_set='train',
                                    year='2007',
                                    transform=transform,
                                    download=True)
        val_set = VOCSegmentation(path,
                                  image_set='val',
                                  year='2007',
                                  transform=transform,
                                  download=True)

    elif dataset == 'VOC10':
        train_set = VOCSegmentation(path,
                                    image_set='train',
                                    year='2010',
                                    transform=transform,
                                    download=True)
        val_set = VOCSegmentation(path,
                                  image_set='val',
                                  year='2010',
                                  transform=transform,
                                  download=True)

    train_loader = data.DataLoader(train_set,
                                   batch_size,
                                   shuffle=True,
                                   num_workers=12)
    val_loader = data.DataLoader(val_set,
                                 1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True)
    return train_loader, val_loader
def load_dataset(dataset):
    train_transform = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(size=32, padding=4),
        T.ToTensor(),
        T.Normalize(
            [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        )  # T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100
    ])

    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(
            [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        )  # T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100
    ])

    if dataset == 'cifar10':
        data_train = CIFAR10('../cifar10',
                             train=True,
                             download=True,
                             transform=train_transform)
        data_unlabeled = MyDataset(dataset, True, test_transform)
        data_test = CIFAR10('../cifar10',
                            train=False,
                            download=True,
                            transform=test_transform)
        NO_CLASSES = 10
        adden = ADDENDUM
        no_train = NUM_TRAIN
    elif dataset == 'cifar10im':
        data_train = CIFAR10('../cifar10',
                             train=True,
                             download=True,
                             transform=train_transform)
        #data_unlabeled   = CIFAR10('../cifar10', train=True, download=True, transform=test_transform)
        targets = np.array(data_train.targets)
        #NUM_TRAIN = targets.shape[0]
        classes, _ = np.unique(targets, return_counts=True)
        nb_classes = len(classes)
        imb_class_counts = [500, 5000] * 5
        class_idxs = [np.where(targets == i)[0] for i in range(nb_classes)]
        imb_class_idx = [
            class_id[:class_count]
            for class_id, class_count in zip(class_idxs, imb_class_counts)
        ]
        imb_class_idx = np.hstack(imb_class_idx)
        no_train = imb_class_idx.shape[0]
        # print(NUM_TRAIN)
        data_train.targets = targets[imb_class_idx]
        data_train.data = data_train.data[imb_class_idx]
        data_unlabeled = MyDataset(dataset[:-2], True, test_transform)
        data_unlabeled.cifar10.targets = targets[imb_class_idx]
        data_unlabeled.cifar10.data = data_unlabeled.cifar10.data[
            imb_class_idx]
        data_test = CIFAR10('../cifar10',
                            train=False,
                            download=True,
                            transform=test_transform)
        NO_CLASSES = 10
        adden = ADDENDUM
        # no_train was already set to the size of the imbalanced subset above
    elif dataset == 'cifar100':
        data_train = CIFAR100('../cifar100',
                              train=True,
                              download=True,
                              transform=train_transform)
        data_unlabeled = MyDataset(dataset, True, test_transform)
        data_test = CIFAR100('../cifar100',
                             train=False,
                             download=True,
                             transform=test_transform)
        NO_CLASSES = 100
        adden = 2000
        no_train = NUM_TRAIN
    elif dataset == 'fashionmnist':
        data_train = FashionMNIST('../fashionMNIST',
                                  train=True,
                                  download=True,
                                  transform=T.Compose([T.ToTensor()]))
        data_unlabeled = MyDataset(dataset, True, T.Compose([T.ToTensor()]))
        data_test = FashionMNIST('../fashionMNIST',
                                 train=False,
                                 download=True,
                                 transform=T.Compose([T.ToTensor()]))
        NO_CLASSES = 10
        adden = ADDENDUM
        no_train = NUM_TRAIN
    elif dataset == 'svhn':
        data_train = SVHN('../svhn',
                          split='train',
                          download=True,
                          transform=T.Compose([T.ToTensor()]))
        data_unlabeled = MyDataset(dataset, True, T.Compose([T.ToTensor()]))
        data_test = SVHN('../svhn',
                         split='test',
                         download=True,
                         transform=T.Compose([T.ToTensor()]))
        NO_CLASSES = 10
        adden = ADDENDUM
        no_train = NUM_TRAIN
    return data_train, data_unlabeled, data_test, adden, NO_CLASSES, no_train
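
load_dataset relies on a module-level MyDataset wrapper and the constants ADDENDUM and NUM_TRAIN, none of which appear in this snippet; a minimal sketch of plausible stand-ins for the CIFAR-10 case, consistent with the .cifar10 attribute the imbalanced branch accesses (assumptions for illustration only):

from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10

NUM_TRAIN = 50000  # size of the full CIFAR-10 training set
ADDENDUM = 1000    # assumed number of samples added to the labeled pool per cycle

class MyDataset(Dataset):
    # Unlabeled view of a dataset, reloaded with the test-time transform.
    # dataset_name is accepted to match the call sites; only CIFAR-10 is sketched here.
    def __init__(self, dataset_name, train, transform):
        self.cifar10 = CIFAR10('../cifar10', train=train, download=True,
                               transform=transform)

    def __getitem__(self, index):
        return self.cifar10[index]

    def __len__(self):
        return len(self.cifar10)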
    def __init__(self,
                 base_dataset='cifar10',
                 take_amount=None,
                 take_amount_seed=13,
                 add_svhn_extra=False,
                 aux_data_filename=None,
                 add_aux_labels=False,
                 aux_take_amount=None,
                 train=False,
                 **kwargs):
        """A dataset with auxiliary pseudo-labeled data"""

        if base_dataset == 'cifar10':
            cifar10_path = get_CIFAR10_path()
            self.dataset = CIFAR10(root=cifar10_path,
                                   train=train,
                                   transform=kwargs['transform'])
        elif base_dataset == 'svhn':
            svhn_path = get_svhn_path()
            if train:
                self.dataset = SVHN(root=svhn_path,
                                    split='train',
                                    transform=kwargs['transform'])
            else:
                self.dataset = SVHN(root=svhn_path,
                                    split='test',
                                    transform=kwargs['transform'])
            # because torchvision is annoying: SVHN stores labels in .labels rather than .targets
            self.dataset.targets = self.dataset.labels
            self.targets = list(self.targets)

            if train and add_svhn_extra:
                svhn_extra = SVHN(root=svhn_path,
                                  split='extra',
                                  transform=kwargs['transform'])
                self.data = np.concatenate([self.data, svhn_extra.data])
                self.targets.extend(svhn_extra.labels)
        else:
            raise ValueError('Dataset %s not supported' % base_dataset)
        self.base_dataset = base_dataset
        self.train = train

        if self.train:
            if take_amount is not None:
                rng_state = np.random.get_state()
                np.random.seed(take_amount_seed)
                take_inds = np.random.choice(len(self.sup_indices),
                                             take_amount,
                                             replace=False)
                np.random.set_state(rng_state)

                logger = logging.getLogger()
                logger.info(
                    'Randomly taking only %d/%d examples from training'
                    ' set, seed=%d, indices=%s', take_amount,
                    len(self.sup_indices), take_amount_seed, take_inds)
                self.targets = self.targets[take_inds]
                self.data = self.data[take_inds]

            self.sup_indices = list(range(len(self.targets)))
            self.unsup_indices = []

            if aux_data_filename is not None:
                aux_path = os.path.join(kwargs['root'], aux_data_filename)
                print("Loading data from %s" % aux_path)
                with open(aux_path, 'rb') as f:
                    aux = pickle.load(f)
                aux_data = aux['data']
                aux_targets = aux['extrapolated_targets']
                orig_len = len(self.data)

                if aux_take_amount is not None:
                    rng_state = np.random.get_state()
                    np.random.seed(take_amount_seed)
                    take_inds = np.random.choice(len(aux_data),
                                                 aux_take_amount,
                                                 replace=False)
                    np.random.set_state(rng_state)

                    logger = logging.getLogger()
                    logger.info(
                        'Randomly taking only %d/%d examples from aux data'
                        ' set, seed=%d, indices=%s', aux_take_amount,
                        len(aux_data), take_amount_seed, take_inds)
                    aux_data = aux_data[take_inds]
                    aux_targets = aux_targets[take_inds]

                self.data = np.concatenate((self.data, aux_data), axis=0)

                if not add_aux_labels:
                    self.targets.extend([-1] * len(aux_data))
                else:
                    self.targets.extend(aux_targets)
                # note that we use unsup indices to track the labeled datapoints
                # whose labels are "fake"
                self.unsup_indices.extend(
                    range(orig_len, orig_len + len(aux_data)))

            logger = logging.getLogger()
            logger.info("Training set")
            logger.info("Number of training samples: %d", len(self.targets))
            logger.info("Number of supervised samples: %d",
                        len(self.sup_indices))
            logger.info("Number of unsup samples: %d", len(self.unsup_indices))
            logger.info(
                "Label (and pseudo-label) histogram: %s",
                tuple(zip(*np.unique(self.targets, return_counts=True))))
            logger.info("Shape of training data: %s", np.shape(self.data))

        # Test set
        else:
            self.sup_indices = list(range(len(self.targets)))
            self.unsup_indices = []

            logger = logging.getLogger()
            logger.info("Test set")
            logger.info("Number of samples: %d", len(self.targets))
            logger.info(
                "Label histogram: %s",
                tuple(zip(*np.unique(self.targets, return_counts=True))))
            logger.info("Shape of data: %s", np.shape(self.data))
Example #11
def get_svhn(location="./", batch_size=64, labels_per_class=1000, extra=True):
    from functools import reduce
    from operator import __or__
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision.datasets import SVHN
    import torchvision.transforms as transforms
    from utils import onehot

    std = np.array([0.19653187, 0.19832356, 0.19942404])  # precalculated
    std = std.reshape(3, 1, 1)
    std = torch.tensor(std).float()

    def flatten(x):
        x = transforms.ToTensor()(x)
        # x += torch.rand(3, 32, 32) / 255.
        # x /= std
        return x.view(-1)

    svhn_train = SVHN(location,
                      split="train",
                      download=True,
                      transform=flatten,
                      target_transform=onehot(n_labels))

    if extra:
        svhn_extra = SVHN(location,
                          split="extra",
                          download=True,
                          transform=flatten,
                          target_transform=onehot(n_labels))
        svhn_train = torch.utils.data.ConcatDataset([svhn_train, svhn_extra])

    print("Len of svhn train", len(svhn_train))
    svhn_valid = SVHN(location,
                      split="test",
                      download=True,
                      transform=flatten,
                      target_transform=onehot(n_labels))

    def uniform_sampler(labels, n=None):
        # Only choose digits in n_labels
        (indices, ) = np.where(
            reduce(__or__, [labels == i for i in np.arange(n_labels)]))

        # Ensure uniform distribution of labels
        np.random.shuffle(indices)
        indices = np.hstack([
            list(filter(lambda idx: labels[idx] == i, indices))[:n]
            for i in range(n_labels)
        ])

        indices = torch.from_numpy(indices)
        sampler = SubsetRandomSampler(indices)
        return sampler

    def get_sampler(dataset_len, n=None):
        # Only choose digits in n_labels
        # (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(n_labels)]))

        # Ensure uniform distribution of labels
        # np.random.shuffle(indices)
        # indices = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(n_labels)])

        if n is None:
            indices = np.arange(dataset_len)
        else:
            indices = np.random.choice(dataset_len, size=n * 10, replace=False)
        indices = torch.from_numpy(indices)
        sampler = SubsetRandomSampler(indices)
        return sampler

    # Dataloaders for SVHN
    labelled = torch.utils.data.DataLoader(svhn_train,
                                           batch_size=batch_size,
                                           num_workers=2,
                                           pin_memory=cuda,
                                           sampler=get_sampler(
                                               len(svhn_train),
                                               labels_per_class))
    unlabelled = torch.utils.data.DataLoader(svhn_train,
                                             batch_size=batch_size,
                                             num_workers=2,
                                             pin_memory=cuda)
    validation = torch.utils.data.DataLoader(svhn_valid,
                                             batch_size=batch_size,
                                             num_workers=2,
                                             pin_memory=cuda)
    return labelled, unlabelled, validation, std