示例#1
0
def get_cifar100_loader(mode, root_path):
    """Build a CIFAR-100 ``DataLoader`` for the requested split.

    Args:
        mode: ``"train"`` selects the training split with
            ``train_transform`` and shuffling; ``"validate"`` selects the
            test split with ``validate_transform`` and no shuffling.
        root_path: Dataset root directory. ``download=False`` is passed,
            so nothing is fetched by this function.

    Returns:
        A ``DataLoader`` configured with ``config.batch_size`` and
        ``config.num_workers``.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"validate"``.
    """
    if mode not in ("train", "validate"):
        raise ValueError("get_loader mode is error")

    # The two modes differ only in split, transform and shuffling.
    is_train = mode == "train"
    split_transform = train_transform if is_train else validate_transform
    dst = CIFAR100(root=root_path,
                   train=is_train,
                   transform=split_transform,
                   download=False)
    return DataLoader(dst,
                      batch_size=config.batch_size,
                      shuffle=is_train,
                      num_workers=config.num_workers)
示例#2
0
def download(dataset_cfg):
    """Export CIFAR-100 as per-label JPEG folders plus a one-hot label file.

    Both splits are downloaded into ``/tmp`` and iterated as one
    concatenation (training split first). Each image is saved as
    ``<image_root_folder>/<label>/<index>.jpg`` and a dict mapping that
    running index to a one-hot ``numpy`` vector is written with ``np.save``
    to ``dataset_cfg.label_file``.

    Args:
        dataset_cfg: Object exposing ``image_root_folder`` and
            ``label_file`` attributes.
    """
    # Training split is created first, then the test split.
    train, test = [
        CIFAR100('/tmp',
                 train=is_train,
                 transform=None,
                 target_transform=None,
                 download=True)
        for is_train in (True, False)
    ]

    img_root = dataset_cfg.image_root_folder
    os.makedirs(img_root, exist_ok=True)
    os.makedirs(os.path.dirname(dataset_cfg.label_file), exist_ok=True)

    # One output sub-folder per class index.
    label_set = set(train.class_to_idx.values())
    num_labels = len(label_set)
    for label in label_set:
        os.makedirs(os.path.join(img_root, str(label)), exist_ok=True)

    one_hot_by_index = {}
    for index, (img, label) in enumerate(train + test):
        one_hot = np.zeros(num_labels)
        one_hot[label] = 1
        one_hot_by_index[index] = one_hot
        img.save(os.path.join(img_root, str(label), '%s.jpg' % index))

    np.save(dataset_cfg.label_file, one_hot_by_index)
示例#3
0
def load_data(data_root_dir='../data/'):
    """Create CIFAR-100 train/test loaders that yield 224-pixel crops.

    The training pipeline resizes to 256, takes a random 224 crop, applies
    a random horizontal flip and colour jitter, converts to a tensor,
    applies random erasing and normalizes with mean/std 0.5. The test
    pipeline resizes to 256 and applies ``TenCrop(224)``; the crops are
    converted, normalized and stacked into a single tensor per sample.

    Args:
        data_root_dir: Dataset root; the data is downloaded if missing.

    Returns:
        ``(data_loaders, data_sizes)``: two dicts keyed by ``'train'`` and
        ``'test'`` holding the ``DataLoader`` and the dataset length.
        Both loaders shuffle; batch sizes are 96 (train) and 48 (test).
    """
    def _crops_to_tensor(crops):
        # Convert every crop to a tensor and stack them along a new axis.
        return torch.stack([transforms.ToTensor()(crop) for crop in crops])

    def _normalize_crops(crops):
        # Normalize each crop independently, then re-stack.
        return torch.stack(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(crop) for crop in crops])

    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Test phase: ten-crop evaluation.
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.TenCrop(224),
        transforms.Lambda(_crops_to_tensor),
        transforms.Lambda(_normalize_crops)
    ])

    # (name, train flag, transform, batch size) for each split.
    split_settings = (
        ('train', True, train_transform, 96),
        ('test', False, test_transform, 48),
    )

    data_loaders, data_sizes = {}, {}
    for name, is_train, transform, batch_size in split_settings:
        data_set = CIFAR100(data_root_dir, train=is_train, download=True, transform=transform)
        data_loaders[name] = DataLoader(data_set, batch_size=batch_size, shuffle=True, num_workers=8)
        data_sizes[name] = len(data_set)
    return data_loaders, data_sizes
def get_dataset(cls, cutout_length=0):
    """Return the augmented CIFAR-100 training set and the validation set.

    The training pipeline applies a random horizontal flip, a random
    resized crop to 32x32 and colour jitter before tensor conversion and
    normalization. The validation pipeline only converts and normalizes.

    Args:
        cls: Dataset identifier; only ``"cifar100"`` is supported.
        cutout_length: Value passed to ``Cutout``. When positive, a
            ``Cutout`` transform is appended to the end of the training
            pipeline (after normalization); ``0`` disables it.

    Returns:
        ``(dataset_train, dataset_valid)`` rooted at ``./data``; the data
        is downloaded if missing.

    Raises:
        NotImplementedError: If ``cls`` is not ``"cifar100"``.
    """
    # Reject unsupported datasets before building any transform.
    if cls != "cifar100":
        raise NotImplementedError

    MEAN = [0.5071, 0.4865, 0.4409]
    STD = [0.1942, 0.1918, 0.1958]

    train_ops = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32)),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
    # The Cutout transform used to be created and then dropped, so
    # ``cutout_length`` had no effect; it is now part of the pipeline.
    if cutout_length > 0:
        train_ops.append(Cutout(cutout_length))

    train_transform = transforms.Compose(train_ops)
    valid_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(MEAN, STD)])

    dataset_train = CIFAR100(root="./data",
                             train=True,
                             download=True,
                             transform=train_transform)
    dataset_valid = CIFAR100(root="./data",
                             train=False,
                             download=True,
                             transform=valid_transform)
    return dataset_train, dataset_valid
示例#5
0
def load_cifar_datasets(path='./dataset',
                        n_class=10,
                        train_transform=get_transform()[0],
                        test_transform=get_transform()[1]):
    """Load the CIFAR-10 or CIFAR-100 train/test datasets.

    Note that the default transforms come from ``get_transform()`` and are
    evaluated once, when this function is defined.

    Args:
        path: Dataset root; the data is downloaded if missing.
        n_class: ``10`` for CIFAR-10 or ``100`` for CIFAR-100.
        train_transform: Transform applied to the training split.
        test_transform: Transform applied to the test split.

    Returns:
        ``(train_dataset, test_dataset)``, or ``(None, None)`` when
        ``n_class`` is neither 10 nor 100.
    """
    if n_class == 10:
        dataset_cls = CIFAR10
    elif n_class == 100:
        dataset_cls = CIFAR100
    else:
        return None, None

    train_dataset = dataset_cls(path,
                                train=True,
                                download=True,
                                transform=train_transform)
    test_dataset = dataset_cls(path,
                               train=False,
                               download=True,
                               transform=test_transform)
    return train_dataset, test_dataset
def get_dataset(cls, cutout_length=0):
    """Return CIFAR-100 train/validation datasets with standard augmentation.

    Training images get a padded random crop and a random horizontal flip,
    are converted to tensors and normalized, and optionally go through
    ``Cutout``. Validation images are only converted and normalized.

    Args:
        cls: Dataset identifier; only ``"cifar100"`` is supported.
        cutout_length: Value passed to ``Cutout``; the transform is added
            to the training pipeline only when this is positive.

    Returns:
        ``(dataset_train, dataset_valid)`` rooted at ``./data``; the data
        is downloaded if missing.

    Raises:
        NotImplementedError: If ``cls`` is not ``"cifar100"``.
    """
    MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
    STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    # Shared by both pipelines.
    to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(MEAN, STD)]

    train_ops = augmentation + to_normalized_tensor
    if cutout_length > 0:
        # Cutout runs last, on the normalized tensor.
        train_ops.append(Cutout(cutout_length))

    train_transform = transforms.Compose(train_ops)
    valid_transform = transforms.Compose(to_normalized_tensor)

    if cls != "cifar100":
        raise NotImplementedError

    dataset_train = CIFAR100(root="./data",
                             train=True,
                             download=True,
                             transform=train_transform)
    dataset_valid = CIFAR100(root="./data",
                             train=False,
                             download=True,
                             transform=valid_transform)
    return dataset_train, dataset_valid
示例#7
0
def load_data(dataset, path, batch_size=64, normalize=False):
  """Build train/validation loaders for SVHN, CIFAR-10 or CIFAR-100.

  Args:
    dataset: One of ``'svhn'``, ``'cifar10'`` or ``'cifar100'``.
    path: Dataset root; the data is downloaded if missing.
    batch_size: Batch size of the training loader.
    normalize: When true, images are normalized with mean/std 0.5 after
      tensor conversion; otherwise they are only converted to tensors.

  Returns:
    ``(train_loader, val_loader)``. The training loader shuffles and uses
    12 workers; the validation loader uses batch size 1, no shuffling,
    one worker and pinned memory. For SVHN the ``'extra'`` split is used
    for training and ``'test'`` for validation.

  Raises:
    ValueError: If ``dataset`` is not one of the supported names.
  """
  # Fail early with a clear message; previously an unknown name fell
  # through every branch and crashed on the unbound ``train_set``.
  if dataset not in ('svhn', 'cifar10', 'cifar100'):
    raise ValueError(
      "Unsupported dataset {!r}; expected 'svhn', 'cifar10' or "
      "'cifar100'".format(dataset))

  if normalize:
    # Wasserstein BiGAN is trained on normalized data.
    transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  else:
    # BiGAN is trained on unnormalized data (see Dumoulin et al. ICLR 16).
    transform = transforms.ToTensor()

  if dataset == 'svhn':
    train_set = SVHN(path, split='extra', transform=transform, download=True)
    val_set = SVHN(path, split='test', transform=transform, download=True)
  elif dataset == 'cifar10':
    train_set = CIFAR10(path, train=True, transform=transform, download=True)
    val_set = CIFAR10(path, train=False, transform=transform, download=True)
  else:  # 'cifar100'
    train_set = CIFAR100(path, train=True, transform=transform, download=True)
    val_set = CIFAR100(path, train=False, transform=transform, download=True)

  train_loader = data.DataLoader(
    train_set, batch_size, shuffle=True, num_workers=12)
  val_loader = data.DataLoader(
    val_set, 1, shuffle=False, num_workers=1, pin_memory=True)
  return train_loader, val_loader
示例#8
0
def get_cifar100():
    """Return CIFAR-100 datasets together with basic shape metadata.

    Training images get a padded random crop and a random horizontal flip
    before tensor conversion and normalization; test images are only
    converted and normalized. Both datasets are rooted at ``D_PTH`` and
    downloaded if missing.

    Returns:
        ``(n_classes, i_channel, i_dim, train_d, test_d)`` where the first
        three entries are the constants 100, 3 and 32.
    """
    n_classes, i_channel, i_dim = 100, 3, 32

    # transforms taken from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_d = CIFAR100(root=D_PTH, train=True, download=True,
                       transform=transform_train)
    test_d = CIFAR100(root=D_PTH, train=False, download=True,
                      transform=transform_test)

    return (n_classes, i_channel, i_dim, train_d, test_d)
示例#9
0
文件: dataset.py 项目: romech/fact-ai
 def setup(self, stage='fit'):
     if self.superclass:
         if stage == 'fit':
             self.cifar_train = CIFAR100Super(
                 self.data_path,
                 train=True,
                 transform=self.train_transforms)
             self.cifar_val = CIFAR100Super(self.data_path,
                                            train=False,
                                            transform=self.test_transforms)
         elif stage == 'test':
             self.cifar_test = CIFAR100Super(self.data_path,
                                             train=False,
                                             transform=self.test_transforms)
     else:
         if stage == 'fit':
             self.cifar_train = CIFAR100(self.data_path,
                                         train=True,
                                         transform=self.train_transforms)
             self.cifar_val = CIFAR100(self.data_path,
                                       train=False,
                                       transform=self.test_transforms)
         elif stage == 'test':
             self.cifar_test = CIFAR100(self.data_path,
                                        train=False,
                                        transform=self.test_transforms)
示例#10
0
def get_dataset(data_path, dataset):
    """Return normalized CIFAR-10 or CIFAR-100 train/test datasets.

    Training images get a random horizontal flip and a padded random crop
    before tensor conversion and normalization; test images are only
    converted and normalized. Each dataset has its own mean/std constants.

    Args:
        data_path: Dataset root; the data is downloaded if missing.
        dataset: ``'cifar10'`` or ``'cifar100'``.

    Returns:
        ``(train_data, test_data)``.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        # Previously an unknown name left ``mean``/``std`` unbound and
        # crashed later with an UnboundLocalError.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'cifar10' or "
            "'cifar100'".format(dataset))

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    dataset_cls = CIFAR10 if dataset == 'cifar10' else CIFAR100
    train_data = dataset_cls(data_path,
                             train=True,
                             transform=train_transform,
                             download=True)
    test_data = dataset_cls(data_path,
                            train=False,
                            transform=test_transform,
                            download=True)
    return train_data, test_data
示例#11
0
    def __init__(self, args):
        """Create the CIFAR-100 train and test loaders.

        Sets ``self.trainLoader`` and ``self.testLoader``. Both use two
        workers and pinned memory; only the training loader shuffles.

        Args:
            args: Object providing ``data_path``, ``train_batch_size`` and
                ``eval_batch_size``.
        """
        # Pinned memory is enabled unconditionally.
        pin_memory = True

        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2023, 0.1994, 0.2010)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        trainset = CIFAR100(root=args.data_path, train=True,
                            download=True, transform=transform_train)
        self.trainLoader = DataLoader(trainset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2,
                                      pin_memory=pin_memory)

        # The test split is not downloaded here (download=False).
        testset = CIFAR100(root=args.data_path, train=False,
                           download=False, transform=transform_test)
        self.testLoader = DataLoader(testset,
                                     batch_size=args.eval_batch_size,
                                     shuffle=False,
                                     num_workers=2,
                                     pin_memory=pin_memory)
def get_data_loader(data_path, batch_size):
    """Return the CIFAR-100 train and test datasets.

    Despite its name this function returns datasets, not loaders, and
    ``batch_size`` is accepted but not used.

    Args:
        data_path: Dataset root; the data is downloaded if missing.
        batch_size: Unused.

    Returns:
        ``(train_set, test_set)``. Training images get a padded random
        crop and a random horizontal flip; both splits are converted to
        tensors and normalized.
    """
    # Normalization constants: ref to EigenDamage code.
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_set = CIFAR100(data_path, train=True, download=True,
                         transform=transform_train)
    test_set = CIFAR100(data_path, train=False, download=True,
                        transform=transform_test)
    return train_set, test_set
示例#13
0
    def __init__(self, data_type, model_type, num_groups, num_query,
                 num_epoch, sampling_type='nbs', save_name='0'):
        """Store the run settings and load the selected CIFAR dataset.

        ``'cifar10'`` selects CIFAR-10 (root ``.cifar10``); any other
        ``data_type`` selects CIFAR-100 (root ``.cifar100``). Both splits
        are wrapped in ``Dataset``, a test loader is created, and
        ``self.indice`` holds the training indices shuffled with a fixed
        seed of 0.
        """
        self.data_type = data_type
        self.model_type = model_type
        self.num_groups = num_groups
        self.num_query = num_query
        self.num_epoch = num_epoch
        self.sampling_type = sampling_type
        self.save_name = save_name

        # Only the source class, its root and the last get_transform
        # argument differ between the two datasets.
        if data_type == 'cifar10':
            source_cls, source_root, transform_arg = CIFAR10, '.cifar10', 16
        else:
            source_cls, source_root, transform_arg = CIFAR100, '.cifar100', 8

        train_source = source_cls(
            root=source_root, train=True, download=True,
            transform=get_transform(32, 4, transform_arg)['train'])
        self.dataset = Dataset(train_source)

        test_source = source_cls(
            root=source_root, train=False, download=True,
            transform=get_transform(32, 4, transform_arg)['test'])
        self.testset = Dataset(test_source)

        self.test_loader = DataLoader(self.testset, batch_size=6144,
                                      num_workers=4, pin_memory=True)

        # Deterministic permutation of the training indices.
        self.indice = list(range(len(self.dataset)))
        random.Random(0).shuffle(self.indice)
示例#14
0
def get_dataset(dataset_root, dataset, train_transform, test_transform):
    """Return train, test and unlabeled views of CIFAR-10 or CIFAR-100.

    ``train`` and ``unlabeled`` are two separate instances of the training
    split, both using ``train_transform``.

    Args:
        dataset_root: Dataset root; the data is downloaded if missing.
        dataset: ``'cifar10'`` or ``'cifar100'``.
        train_transform: Transform for the training-split datasets.
        test_transform: Transform for the test split.

    Returns:
        ``(train, test, unlabeled)``. For an unknown ``dataset`` an error
        message is printed and ``-1`` is returned instead.
    """
    if dataset == 'cifar10':
        dataset_cls = CIFAR10
    elif dataset == 'cifar100':
        dataset_cls = CIFAR100
    else:
        print("Error: No dataset named {}!".format(dataset))
        return -1

    train = dataset_cls(dataset_root, train=True, download=True,
                        transform=train_transform)
    unlabeled = dataset_cls(dataset_root, train=True, download=True,
                            transform=train_transform)
    test = dataset_cls(dataset_root, train=False, download=True,
                       transform=test_transform)
    return train, test, unlabeled
def get_cifar100():
    """Return CIFAR-100 train and test loaders rooted at ``./data``.

    Both loaders use batch size 128, four workers and pinned memory; only
    the training loader shuffles. The training split is downloaded if
    missing, the test split is not (``download=False``).

    Returns:
        ``(train_loader, test_loader)``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Settings shared by both loaders.
    loader_kwargs = dict(batch_size=128, num_workers=4, pin_memory=True)

    trainset = CIFAR100(root='./data', train=True, download=True,
                        transform=transform_train)
    train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)

    testset = CIFAR100(root='./data', train=False, download=False,
                       transform=transform_test)
    test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)

    return train_loader, test_loader
示例#16
0
def get_train_val_test_datasets(
    rnd: np.random.RandomState,
    root='~/data',
    validation_ratio=0.05,
) -> tuple:
    """
    Create CIFAR-100 train/val/test datasets normalized with training statistics.

    The per-channel mean and standard deviation are computed from the
    training data (after the optional validation split has been removed)
    and the resulting normalization transform is assigned to all returned
    datasets.

    :param rnd: `np.random.RandomState` instance, forwarded to `_train_val_split`.
    :param root: Path to save data.
    :param validation_ratio: The ratio of validation data. If this value is not
        greater than `0.`, no split is made and the returned `val_set` is `None`.

    :return: Tuple of (train_set, val_set, test_set); `val_set` may be `None`.
    """

    transform = transforms.Compose([transforms.ToTensor()])

    train_set = CIFAR100(root=root,
                         train=True,
                         download=True,
                         transform=transform)

    # create validation split
    val_set = None
    if validation_ratio > 0.:
        train_set, val_set = _train_val_split(
            rnd=rnd,
            train_dataset=train_set,
            validation_ratio=validation_ratio)

    # Load the whole training set as a single batch to compute statistics.
    train_loader = DataLoader(
        train_set,
        batch_size=len(train_set),
        shuffle=False,
    )

    # ``iter(loader).next()`` is the Python 2 iterator protocol; the
    # built-in ``next`` works with every iterator.
    data = next(iter(train_loader))
    dim = [0, 2, 3]  # reduce over every axis except axis 1
    mean = data[0].mean(dim=dim).numpy()
    std = data[0].std(dim=dim).numpy()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set.transform = transform

    if val_set is not None:
        val_set.transform = transform

    test_set = CIFAR100(root=root,
                        train=False,
                        download=True,
                        transform=transform)

    return train_set, val_set, test_set
示例#17
0
def CIFAR100_datagenerator(batch_size, size):
    """Return CIFAR-100 train and test loaders rooted at ``./data``.

    Args:
        batch_size: Batch size for both loaders.
        size: Accepted but not used by this function.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader
        shuffles; neither drops the last incomplete batch.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std)
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std)
    ])

    cifar100_train = CIFAR100(root='./data', train=True,
                              transform=transform_train, download=True)
    cifar100_test = CIFAR100(root='./data', train=False,
                             transform=transform_test, download=True)

    train_loader = DataLoader(dataset=cifar100_train, batch_size=batch_size,
                              shuffle=True, drop_last=False)
    test_loader = DataLoader(dataset=cifar100_test, batch_size=batch_size,
                             shuffle=False, drop_last=False)
    return train_loader, test_loader
示例#18
0
    def __init__(self, opt, val=False):
        """Load one CIFAR-100 split and build its transform pipeline.

        Sets ``self.dataset`` (created without a transform) and
        ``self.transform``.

        Args:
            opt: Object providing ``dir_dataset``, the dataset root.
            val: When truthy, load the test split with conversion and
                normalization only; otherwise load the training split and
                add a padded random crop and a random horizontal flip.
        """
        super(CustomCIFAR100, self).__init__()
        dir_dataset = opt.dir_dataset

        self.dataset = CIFAR100(root=dir_dataset,
                                train=not val,
                                download=True)

        steps = []
        if not val:
            # Augmentation is applied to the training split only.
            steps.append(RandomCrop((32, 32),
                                    padding=4,
                                    fill=0,
                                    padding_mode='constant'))
            steps.append(RandomHorizontalFlip())
        steps.append(ToTensor())
        steps.append(Normalize(mean=[0.507, 0.487, 0.441],
                               std=[0.267, 0.256, 0.276]))
        self.transform = Compose(steps)
示例#19
0
def get_split_cifar(bsz, num_task):
    """Partition CIFAR-100 across tasks and build per-task loaders.

    Args:
        bsz: Forwarded to ``select_dataset`` as its last argument.
        num_task: Number of tasks; tasks are identified by
            ``0 .. num_task - 1``.

    Returns:
        ``(train_loader, test_loader, labels)``: three dicts keyed by task
        id, filled from the values returned by ``select_dataset``.
    """
    # Both splits share one root. The test split used to point at
    # 'data/cifars_data', which forced a second download of the archive.
    data_root = 'data/cifar_data'

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    trainset = CIFAR100(data_root,
                        train=True,
                        download=True,
                        transform=train_transform)
    testset = CIFAR100(data_root,
                       train=False,
                       download=True,
                       transform=test_transform)

    workers = list(range(num_task))
    partitioner = partition_dataset(trainset, testset, workers, True)
    train_loader, test_loader, labels = {}, {}, {}
    for i in workers:
        train_loader[i], test_loader[i], labels[i] = select_dataset(
            workers, i, partitioner, bsz)
    return train_loader, test_loader, labels
示例#20
0
def Dataset(dataset):
    """Return the train/test datasets for a named benchmark.

    Supported names are ``'cifar10'``, ``'cifar100'``, ``'femnist'``,
    ``'mnist'`` and ``'fashonmnist'`` (spelled exactly like that). CIFAR
    training images get a padded random crop and a random horizontal flip;
    the MNIST-style datasets are padded by 2 pixels with edge replication.
    All images are converted to tensors and normalized.

    Args:
        dataset: Name of the dataset to load.

    Returns:
        ``(trainset, testset)``, or ``(None, None)`` for an unknown name.
    """
    trainset, testset = None, None

    # ``dataset == 'cifar10' or 'cifar100'`` was always truthy, so both
    # transform groups were built for every call; membership tests select
    # only the relevant group.
    if dataset in ('cifar10', 'cifar100'):
        tra_trans = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        val_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        if dataset == 'cifar10':
            trainset = CIFAR10(root="/home/hyf/data", train=True,
                               download=True, transform=tra_trans)
            testset = CIFAR10(root="/home/hyf/data", train=False,
                              download=True, transform=val_trans)
        else:
            trainset = CIFAR100(root="/home/hyf/data", train=True,
                                download=True, transform=tra_trans)
            testset = CIFAR100(root="/home/hyf/data", train=False,
                               download=True, transform=val_trans)

    elif dataset in ('femnist', 'mnist', 'fashonmnist'):
        tra_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        val_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        if dataset == 'femnist':
            trainset = FEMNIST(root='~/data', train=True, download=True, transform=tra_trans)
            testset = FEMNIST(root='~/data', train=False, download=True, transform=val_trans)
        elif dataset == 'mnist':
            trainset = MNIST(root='/home/hyf/data', train=True, download=True, transform=tra_trans)
            testset = MNIST(root='/home/hyf/data', train=False, download=True, transform=val_trans)
        else:
            trainset = FashionMNIST(root='~/data', train=True, download=True, transform=tra_trans)
            testset = FashionMNIST(root='~/data', train=False, download=True, transform=val_trans)

    return trainset, testset
示例#21
0
def get_single_task(dataroot, task):
    """Return the CIFAR-100 train/test subsets selected for one task.

    Each split is loaded with a tensor-conversion transform and then
    passed through ``filter_by_coarse_label`` together with ``task``.

    Args:
        dataroot: Dataset root (nothing is downloaded here).
        task: Forwarded to ``filter_by_coarse_label``.

    Returns:
        ``(trainset, testset)`` as returned by ``filter_by_coarse_label``.
    """
    filtered = []
    for is_train in (True, False):
        full_split = CIFAR100(dataroot, train=is_train,
                              transform=transforms.ToTensor())
        filtered.append(filter_by_coarse_label(full_split, task))

    trainset, testset = filtered
    return trainset, testset
示例#22
0
def _get_cifar100_dataset(dataset_root):
    """Return the CIFAR-100 train and test datasets without transforms.

    Args:
        dataset_root: Dataset root. When ``None``, the location returned
            by ``default_dataset_location('cifar100')`` is used.

    Returns:
        ``(train_set, test_set)``; the data is downloaded if missing.
    """
    root = (default_dataset_location('cifar100')
            if dataset_root is None else dataset_root)

    return (CIFAR100(root, train=True, download=True),
            CIFAR100(root, train=False, download=True))
示例#23
0
文件: cifar100.py 项目: zhwzhong/vega
 def __init__(self, **kwargs):
     """Construct the Cifar100 class.

     ``Dataset.__init__`` runs first with ``kwargs``; ``CIFAR100.__init__``
     is then called with values read from ``self.args``, ``self.train``
     and ``self.transforms``.
     """
     Dataset.__init__(self, **kwargs)

     cifar_kwargs = dict(
         root=self.args.data_path,
         train=self.train,
         transform=Compose(self.transforms.__transform__),
         download=self.args.download,
     )
     CIFAR100.__init__(self, **cifar_kwargs)
def load_cifar100(root='data/', batch_size=128, download=True, num_workers=0, drop_last=False):
    """Return CIFAR-100 train and test dataloaders.

    Args:
        root: Dataset root directory.
        batch_size: Batch size for both loaders.
        download: Forwarded to ``CIFAR100`` for both splits.
        num_workers: Worker count for both loaders.
        drop_last: Applied to the training loader only.

    Returns:
        ``(train_dataloader, test_dataloader)``; only the training loader
        shuffles.
    """
    train_dataset = CIFAR100(root, train=True,
                             transform=cifar100_train_transforms,
                             download=download)
    test_dataset = CIFAR100(root, train=False,
                            transform=cifar100_test_transforms,
                            download=download)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  drop_last=drop_last)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)
    return train_dataloader, test_dataloader
示例#25
0
def cifar100_altaug(transform_t, transform_v):
    """Return CIFAR-100 train/validation datasets with caller-supplied transforms.

    Both datasets live under ``~/Datasets/cifar100``. Only the training
    dataset is created with ``download=True``.

    Args:
        transform_t: Transform for the training split.
        transform_v: Transform for the validation (``train=False``) split.

    Returns:
        ``(dataset_train, dataset_val)``.
    """
    data_root = os.path.expanduser('~/Datasets/cifar100')

    dataset_train = CIFAR100(root=data_root, train=True,
                             transform=transform_t, download=True)
    dataset_val = CIFAR100(root=data_root, train=False,
                           transform=transform_v)
    return dataset_train, dataset_val
示例#26
0
文件: data.py 项目: GaiYu0/f1
def _cifar_arrays(dataset, legacy_prefix):
    """Return ``(images, labels)`` from a CIFAR dataset object.

    Prefers the ``data``/``targets`` attributes and falls back to the
    legacy ``<prefix>_data``/``<prefix>_labels`` pair when ``data`` is
    absent.
    """
    if hasattr(dataset, 'data'):
        return dataset.data, dataset.targets
    return (getattr(dataset, legacy_prefix + '_data'),
            getattr(dataset, legacy_prefix + '_labels'))


def load_cifar100():
    """Load all of CIFAR-100 (train followed by test) as flat tensors.

    The image arrays of both splits are concatenated, the last axis of the
    resulting 4-D array is moved to position 1, and every sample is
    flattened to a single row. Nothing is downloaded; the data is read
    from the ``CIFAR100`` root directory.

    Returns:
        ``(x, y)``: ``x`` is a float tensor with one row per image and
        ``y`` a tensor with the matching labels.
    """
    # The per-split attribute names (train_data/test_data, ...) are not
    # available on every torchvision version, so read them via a helper.
    ax, ay = _cifar_arrays(CIFAR100(root='CIFAR100', train=True), 'train')
    bx, by = _cifar_arrays(CIFAR100(root='CIFAR100', train=False), 'test')

    x, y = np.concatenate([ax, bx]), np.concatenate([ay, by])
    x = x.transpose([0, 3, 1, 2]).reshape([len(x), -1])
    x, y = th.from_numpy(x).float(), th.from_numpy(y)
    return x, y
示例#27
0
def _get_cifar100_dataset(train_transformation, eval_transformation):
    """Load CIFAR-100 and wrap both splits via ``train_eval_avalanche_datasets``.

    The raw datasets are created without transforms under
    ``~/.avalanche/data/cifar100/`` and downloaded if missing.

    Args:
        train_transformation: Forwarded to ``train_eval_avalanche_datasets``.
        eval_transformation: Forwarded to ``train_eval_avalanche_datasets``.

    Returns:
        Whatever ``train_eval_avalanche_datasets`` returns.
    """
    data_dir = expanduser("~") + "/.avalanche/data/cifar100/"

    train_set = CIFAR100(data_dir, train=True, download=True)
    test_set = CIFAR100(data_dir, train=False, download=True)

    return train_eval_avalanche_datasets(
        train_set, test_set, train_transformation, eval_transformation)
示例#28
0
def _get_cifar100_dataset():
    """Return the CIFAR-100 train and test datasets without transforms.

    Both datasets are stored under ``~/.avalanche/data/cifar100/`` and
    downloaded if missing.

    Returns:
        ``(train_set, test_set)``.
    """
    data_dir = expanduser("~") + "/.avalanche/data/cifar100/"

    train_set = CIFAR100(data_dir, train=True, download=True)
    test_set = CIFAR100(data_dir, train=False, download=True)
    return train_set, test_set
示例#29
0
    def build(self):
        """Create the CIFAR-100 train and test datasets.

        Both datasets are read from ``self.data_dir`` and downloaded if
        missing. The first is created without an explicit ``train``
        argument and uses ``self.train_trans``; the second passes
        ``train=False`` and uses ``self.test_trans``.

        Returns:
            ``(train_dt, test_dt)``.
        """
        root = self.data_dir
        train_dt = CIFAR100(root, transform=self.train_trans, download=True)
        test_dt = CIFAR100(root, train=False, transform=self.test_trans,
                           download=True)
        return train_dt, test_dt
def cifar100(root):
    """Return CIFAR-100 train/test datasets resized to 224 pixels.

    Both splits share one pipeline: resize to 224, convert to a tensor and
    normalize. Only the training dataset is created with
    ``download=True``.

    Args:
        root: Dataset root directory.

    Returns:
        ``(trainset, testset)``.
    """
    from torchvision.datasets import CIFAR100

    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)
    pipeline = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    trainset = CIFAR100(root, train=True, transform=pipeline, download=True)
    testset = CIFAR100(root, train=False, transform=pipeline)
    return trainset, testset