# NOTE: the functions below assume the usual module-level imports are in scope
# (os, copy, shutil, torch, torchvision, torchvision.transforms as transforms, PIL.Image,
# sklearn's StratifiedShuffleSplit/PredefinedSplit, and the project's own helpers: the
# config object C, logger, Augmentation, CutoutDefault, the augmentation policies,
# Subset/ConcatDataset and the custom samplers).
def get_dataloaders(dataset, batch, dataroot, split=0.0, split_idx=0, horovod=False):
    if 'cifar' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    elif 'imagenet' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
            transforms.ToTensor(),
            # Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        if C.get()['model']['type'] == 'resnet200':
            # Instead, we test a single 320x320 crop from s=320.
            transform_test = transforms.Compose([
                transforms.Resize(320),
                transforms.CenterCrop(320),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            transform_test = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    else:
        raise ValueError('dataset=%s' % dataset)

    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'random2048':
            transform_train.transforms.insert(0, Augmentation(random_search2048()))
        elif C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_imagenet()))
        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.train_labels)
        train_idx, valid_idx = next(sss)
        train_labels = [total_trainset.train_labels[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.train_labels = train_labels

        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/train'), transform=transform_train)
        testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/val'), transform=transform_test)

        # compatibility
        total_trainset.train_labels = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        idx120 = [
            904, 385, 759, 884, 784, 844, 132, 214, 990, 786, 979, 582, 104, 288, 697, 480, 66, 943, 308, 282,
            118, 926, 882, 478, 133, 884, 570, 964, 825, 656, 661, 289, 385, 448, 705, 609, 955, 5, 703, 713,
            695, 811, 958, 147, 6, 3, 59, 354, 315, 514, 741, 525, 685, 673, 657, 267, 575, 501, 30, 455,
            905, 860, 355, 911, 24, 708, 346, 195, 660, 528, 330, 511, 439, 150, 988, 940, 236, 803, 741, 295,
            111, 520, 856, 248, 203, 147, 625, 589, 708, 201, 712, 630, 630, 367, 273, 931, 960, 274, 112, 239,
            463, 355, 955, 525, 404, 59, 981, 725, 90, 782, 604, 323, 418, 35, 95, 97, 193, 690, 869, 172
        ]
        total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/train'), transform=transform_train)
        testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'imagenet/val'), transform=transform_test)

        # compatibility
        total_trainset.train_labels = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 500000, random_state=0)  # keep 500000 candidates before class filtering
        sss = sss.split(list(range(len(total_trainset))), total_trainset.train_labels)
        train_idx, valid_idx = next(sss)

        # keep only the 120 selected classes and remap their labels to 0..119
        train_idx = list(filter(lambda x: total_trainset.train_labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.train_labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        train_labels = [idx120.index(total_trainset.train_labels[idx]) for idx in train_idx]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.train_labels = train_labels

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if split > 0.0:
        sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
        sss = sss.split(list(range(len(total_trainset))), total_trainset.train_labels)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if horovod:
            import horovod.torch as hvd
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_sampler, num_replicas=hvd.size(), rank=hvd.rank())
    else:
        valid_sampler = SubsetSampler([])

        if horovod:
            import horovod.torch as hvd
            train_sampler = DistributedStratifiedSampler(total_trainset.train_labels, num_replicas=hvd.size(), rank=hvd.rank())
        else:
            train_sampler = StratifiedSampler(total_trainset.train_labels)

    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False,
        num_workers=32, pin_memory=True, sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False,
        num_workers=16, pin_memory=True, sampler=valid_sampler, drop_last=False)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False,
        num_workers=32, pin_memory=True, drop_last=False)
    return train_sampler, trainloader, validloader, testloader
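# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch of how get_dataloaders() above is typically called. It assumes
# the global config C has already been initialised elsewhere with at least the
# 'aug', 'cutout' and 'model' keys the function reads; the helper name, dataroot
# and batch size are placeholders for illustration only.
def _example_cifar_loaders(dataroot='./data', batch=128):
    train_sampler, trainloader, validloader, testloader = get_dataloaders(
        'reduced_cifar10', batch=batch, dataroot=dataroot, split=0.15, split_idx=0)
    # trainloader draws from the stratified 4,000-example reduced CIFAR-10 subset;
    # validloader shares the same underlying dataset but samples the held-out split indices.
    return trainloader, validloader, testloader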
def load_dataset(data_dir, resize, dataset_name, img_type):
    if dataset_name == 'cifar_10':
        mean = cifar_10['mean']
        std = cifar_10['std']
    elif dataset_name == 'cifar_100':
        mean = cifar_100['mean']
        std = cifar_100['std']
    else:
        print('Dataset not recognized. Data normalized with equal mean/std weights.')
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

    hdf5_folder = '{}/hdf5'.format(data_dir)
    if os.path.exists(hdf5_folder):
        shutil.rmtree(hdf5_folder)
    create_hdf5(data_dir, resize, dataset_name, img_type)

    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        train_transform.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'random2048':
            train_transform.transforms.insert(0, Augmentation(random_search2048()))
        elif C.get()['aug'] == 'fa_reduced_cifar10':
            train_transform.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            train_transform.transforms.insert(0, Augmentation(fa_reduced_imagenet()))
        elif C.get()['aug'] == 'arsaug':
            train_transform.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            train_transform.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            train_transform.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        train_transform.transforms.append(CutoutDefault(C.get()['cutout']))

    hdf5_folder = '{}/hdf5'.format(data_dir)
    hdf5_train_path = '{}/{}_{}.hdf5'.format(hdf5_folder, dataset_name, 'training')
    hdf5_test_path = '{}/{}_{}.hdf5'.format(hdf5_folder, dataset_name, 'test')

    train_dataset = CustomDataset(hdf5_file=hdf5_train_path, transform=train_transform)
    val_dataset = CustomDataset(hdf5_file=hdf5_train_path, transform=test_transform)
    test_dataset = CustomDataset(hdf5_file=hdf5_test_path, transform=test_transform)
    train_dataset.train_labels = train_dataset.labels_id

    return [train_dataset, val_dataset, test_dataset]
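# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch showing how the three datasets returned by load_dataset() are
# typically wrapped into DataLoaders. The helper name, the img_type value and the
# batch size are assumptions for this example only.
def _example_hdf5_loaders(data_dir, resize=32, dataset_name='cifar_10', img_type='png', batch=128):
    train_dataset, val_dataset, test_dataset = load_dataset(data_dir, resize, dataset_name, img_type)
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True,
                                              num_workers=4, pin_memory=True, drop_last=True)
    validloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch, shuffle=False,
                                              num_workers=4, pin_memory=True)
    testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, shuffle=False,
                                             num_workers=4, pin_memory=True)
    return trainloader, validloader, testloader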
def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode=False, target_lb=-1,
                    gr_assign=None, gr_id=None, gr_ids=None, rand_val=False):
    if 'cifar' in dataset or 'svhn' in dataset:
        if "cifar" in dataset:
            _mean, _std = _CIFAR_MEAN, _CIFAR_STD
        else:
            _mean, _std = _SVHN_MEAN, _SVHN_STD
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_mean, _std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_mean, _std),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256
        if 'efficientnet' in C.get()['model']['type']:
            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
            sized_size = input_size + 32  # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))

        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        raise ValueError('dataset=%s' % dataset)

    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    elif isinstance(C.get()['aug'], dict):
        # group version
        logger.debug('group augmentation provided.')
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))
        elif C.get()['aug'] == 'fa_reduced_svhn':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))
        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', "clean", "nonorm", "nocut"]:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0 and C.get()['aug'] != "nocut":
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if C.get()['aug'] == "clean":
        transform_train = transform_test
    elif C.get()['aug'] == "nonorm":
        transform_train = transforms.Compose([transforms.ToTensor()])

    train_idx = valid_idx = None
    if dataset == 'cifar10':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugCIFAR10(root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], train=True, download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=False, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugCIFAR10(root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], train=True, download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=False, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=5, train_size=4000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)

        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'cifar100':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugData("CIFAR100", root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], train=True, download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=False, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'svhn':  # TODO
        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=False, transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=False, transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=transform_test)
    elif dataset == 'reduced_svhn':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugData("SVHN", root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], split='train', download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=False, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=5, train_size=1000, test_size=7325, random_state=0)
        sss = sss.split(list(range(len(total_trainset))), total_trainset.labels)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)
        # targets = [total_trainset.labels[idx] for idx in train_idx]
        # total_trainset = Subset(total_trainset, train_idx)
        # total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        # idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181, 189, 202, 210, 232, 238,
                  257, 258, 259, 277, 283, 289, 295, 304, 307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381,
                  388, 399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497, 506, 512, 530, 541, 553, 554,
                  557, 564, 570, 584, 612, 614, 619, 626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695,
                  699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787, 797, 799, 811, 822, 829, 830, 835,
                  837, 842, 843, 845, 873, 883, 897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949, 959]
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 50000, random_state=0)  # keep 50000 candidates before class filtering
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # keep only the 120 selected classes and remap their labels to 0..119
        train_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    elif dataset == "cifar10_svhn":
        if isinstance(C.get()['aug'], dict):
            # last stage: benchmark test
            total_trainset = GrAugMix(dataset.split("_"), gr_assign=gr_assign, gr_policies=C.get()['aug'], root=dataroot, train=True, download=False, transform=transform_train, gr_ids=gr_ids)
        else:
            # eval_tta & childnet training
            total_trainset = GrAugMix(dataset.split("_"), root=dataroot, train=True, download=False, transform=transform_train)
        testset = GrAugMix(dataset.split("_"), root=dataroot, train=False, download=False, transform=transform_test)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if not hasattr(total_trainset, "gr_ids"):
        total_trainset.gr_ids = None
    if gr_ids is not None:
        total_trainset.gr_ids = gr_ids
    if gr_assign is not None and total_trainset.gr_ids is None:
        # eval_tta3
        temp_trainset = copy.deepcopy(total_trainset)
        # temp_trainset.transform = transform_test  # just normalize
        temp_loader = torch.utils.data.DataLoader(temp_trainset, batch_size=batch, shuffle=False, num_workers=4, drop_last=False)
        gr_dist = gr_assign(temp_loader)
        gr_ids = torch.max(gr_dist, dim=1)[1].numpy()  # argmax over the group dimension; torch.max needs dim= to return indices

    if split > 0.0:
        if train_idx is None or valid_idx is None:
            # filter by split ratio
            sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
            sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
            for _ in range(split_idx + 1):
                train_idx, valid_idx = next(sss)

        if gr_id is not None:
            # filter by group
            idx2gr = total_trainset.gr_ids
            ps = PredefinedSplit(idx2gr)
            ps = ps.split()
            for _ in range(gr_id + 1):
                _, gr_split_idx = next(ps)
            train_idx = [idx for idx in train_idx if idx in gr_split_idx]
            valid_idx = [idx for idx in valid_idx if idx in gr_split_idx]

        if target_lb >= 0:
            train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
            valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx) if not rand_val else SubsetRandomSampler(valid_idx)

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(Subset(total_trainset, train_idx), num_replicas=dist.get_world_size(), rank=dist.get_rank())
    else:
        train_sampler = None
        valid_sampler = SubsetSampler([])

        if gr_id is not None:
            # filter by group
            idx2gr = total_trainset.gr_ids
            ps = PredefinedSplit(idx2gr)
            ps = ps.split()
            for _ in range(gr_id + 1):
                _, gr_split_idx = next(ps)
            targets = [total_trainset.targets[idx] for idx in gr_split_idx]
            total_trainset = Subset(total_trainset, gr_split_idx)
            total_trainset.targets = targets

        if train_idx is not None and valid_idx is not None:
            if dataset in ["svhn", "reduced_svhn"]:
                targets = [total_trainset.labels[idx] for idx in train_idx]
            else:
                targets = [total_trainset.targets[idx] for idx in train_idx]
            total_trainset = Subset(total_trainset, train_idx)
            total_trainset.targets = targets

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
            logger.info(f'----- dataset with DistributedSampler {dist.get_rank()}/{dist.get_world_size()}')

    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False,
        num_workers=8 if torch.cuda.device_count() == 8 else 4, pin_memory=True,
        sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False,
        num_workers=4, pin_memory=True,
        sampler=valid_sampler, drop_last=False if not rand_val else True)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False,
        num_workers=8 if torch.cuda.device_count() == 8 else 4, pin_memory=True, drop_last=False)
    return train_sampler, trainloader, validloader, testloader
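# --- Illustrative sketch (not part of the original module) --------------------
# Shows the shape of a gr_assign callable accepted by the group-aware
# get_dataloaders() above: it receives a DataLoader over the training set and
# returns a per-sample group-probability tensor of shape (N, n_groups); hard
# group ids are then taken with an argmax, mirroring the torch.max(..., dim=1)
# step inside the function. The name random_gr_assign and the number of groups
# are placeholders for this example only.
def random_gr_assign(dataloader, n_groups=2):
    n_samples = len(dataloader.dataset)
    return torch.rand(n_samples, n_groups)  # pseudo group distribution per sample

# Example conversion to hard ids, as done inside get_dataloaders():
#   gr_dist = random_gr_assign(temp_loader)
#   gr_ids = torch.max(gr_dist, dim=1)[1].numpy()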
def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode=False, target_lb=-1):
    if 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256
        if 'efficientnet' in C.get()['model']['type']:
            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
            sized_size = input_size + 32  # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))

        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    elif 'gta5' in dataset:
        transform_target_after = transforms.Compose([PILToLongTensor()])
        transform_train_pre = Compose([RandomCrop((321, 321)), RandomHorizontallyFlip()])  # weak transform
        transform_valid_pre = Compose([RandomCrop((321, 321))])  # weak transform
        transform_train = transforms.Compose([transforms.ToTensor()])
        transform_test_pre = None  # Compose([RandomCrop((321, 321))])
        transform_test = transforms.Compose([transforms.ToTensor()])
    else:
        raise ValueError('dataset=%s' % dataset)

    total_aug = augs = None
    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))
        elif C.get()['aug'] == 'fa_reduced_svhn':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))
        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=73257 - 1000, random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train, download=True)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        # idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181, 189, 202, 210, 232, 238,
                  257, 258, 259, 277, 283, 289, 295, 304, 307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381,
                  388, 399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497, 506, 512, 530, 541, 553, 554,
                  557, 564, 570, 584, 612, 614, 619, 626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695,
                  699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787, 797, 799, 811, 822, 829, 830, 835,
                  837, 842, 843, 845, 873, 883, 897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949, 959]
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 50000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out
        train_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    elif dataset == 'gta5':
        total_trainset = GTA5_Dataset(data_root_path=dataroot, split='train', transform_pre=transform_train_pre,
                                      transform_target_after=transform_target_after, transform_after=transform_train, sample=10000)
        total_validset = GTA5_Dataset(data_root_path=dataroot, split='valid', transform_pre=transform_valid_pre,
                                      transform_target_after=transform_target_after, transform_after=transform_test, sample=5000, seed=0)
        testset = GTA5_Dataset(data_root_path=dataroot, split='valid', transform_pre=transform_test_pre,
                               transform_target_after=transform_target_after, transform_after=transform_test)
    elif dataset == 'reduced_gta5':
        total_trainset = GTA5_Dataset(data_root_path=dataroot, split='train', transform_pre=transform_train_pre,
                                      transform_target_after=transform_target_after, transform_after=transform_train, sample=1000, seed=0)
        total_validset = GTA5_Dataset(data_root_path=dataroot, split='valid', transform_pre=transform_valid_pre,
                                      transform_target_after=transform_target_after, transform_after=transform_test, sample=1000, seed=0)
        testset = GTA5_Dataset(data_root_path=dataroot, split='valid', transform_pre=transform_test_pre,
                               transform_target_after=transform_target_after, transform_after=transform_test)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

    train_sampler = None
    if split > 0.0:
        if 'gta' not in dataset:
            sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
            sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
            for _ in range(split_idx + 1):
                train_idx, valid_idx = next(sss)

            if target_lb >= 0:
                train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
                valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]

            train_sampler = SubsetRandomSampler(train_idx)
            valid_sampler = SubsetSampler(valid_idx)

            if multinode:
                train_sampler = torch.utils.data.distributed.DistributedSampler(Subset(total_trainset, train_idx), num_replicas=dist.get_world_size(), rank=dist.get_rank())
        else:
            train_sampler = None
            valid_sampler = None
    else:
        valid_sampler = SubsetSampler([])

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
            logger.info(f'----- dataset with DistributedSampler {dist.get_rank()}/{dist.get_world_size()}')

    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False,
        num_workers=8, pin_memory=True, sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False,
        num_workers=4, pin_memory=True, sampler=valid_sampler, drop_last=False)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False,
        num_workers=8, pin_memory=True, drop_last=False)
    return train_sampler, trainloader, validloader, testloader
def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, horovod=False, target_lb=-1):
    # torchvision 0.2 (sh36 r0.3.0): uses 'train_labels'
    # torchvision 0.4 (local): uses 'targets'
    # torchvision 0.4.1 (sh36 r0.3.2): uses 'targets' (module has no '__version__' attribute)
    using_attr_train_labels = False
    try:
        torchvision_version = torchvision.__version__
        # NOTE: lexicographic string compare; adequate only for the versions listed above.
        if torchvision_version < '0.4':
            using_attr_train_labels = True
    except AttributeError:
        pass

    if 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256, interpolation=Image.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        raise ValueError('dataset=%s' % dataset)

    total_aug = augs = None
    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))
        elif C.get()['aug'] == 'fa_reduced_svhn':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))
        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=False, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=False, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=False, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=False, transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=False, transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=False, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=73257 - 1000, random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.train_labels if using_attr_train_labels else total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        idx120 = [904, 385, 759, 884, 784, 844, 132, 214, 990, 786, 979, 582, 104, 288, 697, 480, 66, 943, 308, 282,
                  118, 926, 882, 478, 133, 884, 570, 964, 825, 656, 661, 289, 385, 448, 705, 609, 955, 5, 703, 713,
                  695, 811, 958, 147, 6, 3, 59, 354, 315, 514, 741, 525, 685, 673, 657, 267, 575, 501, 30, 455,
                  905, 860, 355, 911, 24, 708, 346, 195, 660, 528, 330, 511, 439, 150, 988, 940, 236, 803, 741, 295,
                  111, 520, 856, 248, 203, 147, 625, 589, 708, 201, 712, 630, 630, 367, 273, 931, 960, 274, 112, 239,
                  463, 355, 955, 525, 404, 59, 981, 725, 90, 782, 604, 323, 418, 35, 95, 97, 193, 690, 869, 172]
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 500000, random_state=0)  # keep 500000 candidates before class filtering
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # keep only the 120 selected classes and remap their labels to 0..119
        train_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

    train_sampler = None
    if split > 0.0:
        sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.train_labels if using_attr_train_labels else total_trainset.targets)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)

        if target_lb >= 0:
            train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
            valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if horovod:
            import horovod.torch as hvd
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_sampler, num_replicas=hvd.size(), rank=hvd.rank())
    else:
        valid_sampler = SubsetSampler([])

        if horovod:
            import horovod.torch as hvd
            # shard the full training set across horovod workers
            train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=hvd.size(), rank=hvd.rank())

    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False,
        num_workers=32, pin_memory=True, sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False,
        num_workers=16, pin_memory=True, sampler=valid_sampler, drop_last=False)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False,
        num_workers=32, pin_memory=True, drop_last=False)
    return train_sampler, trainloader, validloader, testloader