def get_dataloaders(dataset: str, batch: int, dataroot: str, split: float = 0.0, split_idx: int = 0, horovod: bool = False):
    """Build the train/valid/test DataLoaders (plus the train sampler).

    Args:
        dataset: one of 'cifar10', 'reduced_cifar10', 'cifar100', 'imagenet',
            'reduced_imagenet'. Transform selection keys off the substrings
            'cifar' / 'imagenet', so every supported name must contain one.
        batch: batch size used by all three loaders.
        dataroot: root directory of the on-disk dataset files.
        split: fraction of the train set held out for validation
            (0.0 disables the validation split).
        split_idx: which of the 5 stratified folds to use when split > 0.
        horovod: if True, wrap the train sampler for distributed training.

    Returns:
        (train_sampler, trainloader, validloader, testloader)

    Raises:
        ValueError: for an unrecognized dataset name or augmentation name.
    """
    # --- per-dataset base transforms ---------------------------------------
    if 'cifar' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # CIFAR-10 channel statistics (also applied to cifar100 here).
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    elif 'imagenet' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2,
            ),
            transforms.ToTensor(),
            # Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        if C.get()['model']['type'] == 'resnet200':
            # resnet200 is evaluated on a single 320x320 center crop (s = 320).
            transform_test = transforms.Compose([
                transforms.Resize(320),
                transforms.CenterCrop(320),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            # Standard ImageNet eval protocol: resize to 256, center-crop 224.
            transform_test = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    else:
        raise ValueError('dataset=%s' % dataset)

    # --- prepend the configured augmentation policy to the train transform --
    if isinstance(C.get()['aug'], list):
        # An explicit policy (list of sub-policies) was provided in the config.
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'random2048':
            transform_train.transforms.insert(0, Augmentation(random_search2048()))
        elif C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_imagenet()))
        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            # No extra augmentation policy for these names.
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    # Cutout is appended (i.e. applied after ToTensor/Normalize) when enabled.
    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    # --- dataset instantiation ---------------------------------------------
    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        # Stratified 4000-sample subset: 50000 - 46000 held out as "test".
        sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.train_labels)
        train_idx, valid_idx = next(sss)
        train_labels = [total_trainset.train_labels[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        # Stuff train_labels onto the Subset so the stratified samplers below
        # (which read .train_labels) still work.
        total_trainset.train_labels = train_labels
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/train'), transform=transform_train)
        testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/val'), transform=transform_test)

        # compatibility: ImageFolder has no train_labels attribute, but the
        # stratified samplers below expect one.
        total_trainset.train_labels = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        # NOTE(review): this list contains duplicates (e.g. 884, 385, 355
        # appear twice), so it covers fewer than 120 distinct classes, and
        # idx120.index(...) below maps a duplicated class to its first
        # occurrence — confirm this is intended.
        idx120 = [
            904, 385, 759, 884, 784, 844, 132, 214, 990, 786, 979, 582, 104,
            288, 697, 480, 66, 943, 308, 282, 118, 926, 882, 478, 133, 884,
            570, 964, 825, 656, 661, 289, 385, 448, 705, 609, 955, 5, 703,
            713, 695, 811, 958, 147, 6, 3, 59, 354, 315, 514, 741, 525, 685,
            673, 657, 267, 575, 501, 30, 455, 905, 860, 355, 911, 24, 708,
            346, 195, 660, 528, 330, 511, 439, 150, 988, 940, 236, 803, 741,
            295, 111, 520, 856, 248, 203, 147, 625, 589, 708, 201, 712, 630,
            630, 367, 273, 931, 960, 274, 112, 239, 463, 355, 955, 525, 404,
            59, 981, 725, 90, 782, 604, 323, 418, 35, 95, 97, 193, 690, 869,
            172
        ]
        total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/train'), transform=transform_train)
        testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/val'), transform=transform_test)

        # compatibility: expose labels as .train_labels for the samplers.
        total_trainset.train_labels = [lb for _, lb in total_trainset.samples]

        # Keep 500000 samples in the "train" side of the split; the remainder
        # goes to valid_idx. (The '4000 trainset' comment was inherited from
        # the cifar branch and does not apply here.)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 500000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.train_labels)
        train_idx, valid_idx = next(sss)

        # filter out: keep only samples whose (original) label is in idx120.
        train_idx = list(filter(lambda x: total_trainset.train_labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.train_labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        # Remap original 1000-class labels to the compact 0..119 range.
        train_labels = [idx120.index(total_trainset.train_labels[idx]) for idx in train_idx]

        # Rewrite the kept samples' labels in place to the compact range.
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.train_labels = train_labels

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    # --- train/valid split and sampler selection ---------------------------
    if split > 0.0:
        # 5 stratified folds; advance the iterator to pick fold `split_idx`.
        sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
        sss = sss.split(list(range(len(total_trainset))), total_trainset.train_labels)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if horovod:
            import horovod.torch as hvd
            # NOTE(review): DistributedSampler is given the *sampler* here,
            # not a dataset — confirm this is the intended usage.
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_sampler, num_replicas=hvd.size(), rank=hvd.rank())
    else:
        # No validation split: empty valid sampler, stratified train sampler.
        valid_sampler = SubsetSampler([])

        if horovod:
            import horovod.torch as hvd
            train_sampler = DistributedStratifiedSampler(total_trainset.train_labels, num_replicas=hvd.size(), rank=hvd.rank())
        else:
            train_sampler = StratifiedSampler(total_trainset.train_labels)

    # train_sampler is always set above, so shuffle evaluates to False here;
    # shuffling is delegated entirely to the sampler.
    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False,
        num_workers=32, pin_memory=True, sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False,
        num_workers=16, pin_memory=True, sampler=valid_sampler, drop_last=False)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False,
        num_workers=32, pin_memory=True, drop_last=False)
    return train_sampler, trainloader, validloader, testloader
def load_dataset(data_dir, resize, dataset_name, img_type):
    """Build train/val/test ``CustomDataset``s backed by freshly generated HDF5 files.

    The HDF5 cache under ``<data_dir>/hdf5`` is deleted and regenerated on
    every call via ``create_hdf5``. The validation dataset reuses the training
    HDF5 file but applies the (augmentation-free) test transform.

    Args:
        data_dir: directory that holds (and will hold) the HDF5 cache.
        resize: target crop size for the training transform (also forwarded
            to ``create_hdf5``).
        dataset_name: 'cifar_10', 'cifar_100', or any other name (which falls
            back to mean/std of 0.5 per channel).
        img_type: forwarded to ``create_hdf5``.

    Returns:
        [train_dataset, val_dataset, test_dataset]

    Raises:
        ValueError: for an unrecognized augmentation name in the config.
    """
    if dataset_name == 'cifar_10':
        mean = cifar_10['mean']
        std = cifar_10['std']
    elif dataset_name == 'cifar_100':
        mean = cifar_100['mean']
        std = cifar_100['std']
    else:
        # Fix: use the module logger (as the rest of this function does)
        # instead of a bare print for this diagnostic.
        logger.warning(
            'Dataset not recognized. Data normalize with equal mean/std weights'
        )
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

    # Regenerate the HDF5 cache from scratch on every call.
    # (Fix: hdf5_folder was previously computed twice from the same format.)
    hdf5_folder = '{}/hdf5'.format(data_dir)
    if os.path.exists(hdf5_folder):
        shutil.rmtree(hdf5_folder)
    create_hdf5(data_dir, resize, dataset_name, img_type)

    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

    # Prepend the configured augmentation policy to the train transform.
    if isinstance(C.get()['aug'], list):
        # An explicit policy (list of sub-policies) was provided in the config.
        logger.debug('augmentation provided.')
        train_transform.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'random2048':
            train_transform.transforms.insert(0, Augmentation(random_search2048()))
        elif C.get()['aug'] == 'fa_reduced_cifar10':
            train_transform.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            train_transform.transforms.insert(0, Augmentation(fa_reduced_imagenet()))
        elif C.get()['aug'] == 'arsaug':
            train_transform.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            train_transform.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            train_transform.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            # No extra augmentation policy for these names.
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    # Cutout is appended (applied after ToTensor/Normalize) when enabled.
    if C.get()['cutout'] > 0:
        train_transform.transforms.append(CutoutDefault(C.get()['cutout']))

    hdf5_train_path = '{}/{}_{}.hdf5'.format(hdf5_folder, dataset_name, 'training')
    hdf5_test_path = '{}/{}_{}.hdf5'.format(hdf5_folder, dataset_name, 'test')

    train_dataset = CustomDataset(hdf5_file=hdf5_train_path, transform=train_transform)
    # Validation shares the training file but uses the un-augmented transform.
    val_dataset = CustomDataset(hdf5_file=hdf5_train_path, transform=test_transform)
    test_dataset = CustomDataset(hdf5_file=hdf5_test_path, transform=test_transform)

    # compatibility: expose labels as .train_labels for downstream samplers.
    train_dataset.train_labels = train_dataset.labels_id
    return [train_dataset, val_dataset, test_dataset]