def __init__(self):
    # image augmentation functions
    self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    self.train_transform = transforms.Compose([
        # interpolation=3 corresponds to PIL.Image.BICUBIC
        transforms.RandomResizedCrop(128, scale=(0.2, 0.9), ratio=(0.7, 1.4),
                                     interpolation=3),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        transforms.RandomGrayscale(p=0.25),
        transforms.ToTensor(),
        Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
        normalize
    ])
    self.test_transform = transforms.Compose([
        transforms.Resize(146, interpolation=3),  # bicubic
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        normalize
    ])
    # prepend the Fast AutoAugment policy found on reduced ImageNet
    self.train_transform.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))
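
The snippet above leans on two helpers defined elsewhere in the repository: a Lighting transform and an _IMAGENET_PCA table. As a reference, here is a minimal sketch assuming the usual AlexNet-style PCA lighting noise; the eigenvalues and eigenvectors below are the standard ImageNet RGB statistics used by that recipe, and the repository's own definitions may differ:

import torch

_IMAGENET_PCA = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675,  0.7192,  0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948,  0.4203],
    ])
}

class Lighting(object):
    """AlexNet-style PCA lighting noise; expects a CxHxW tensor in [0, 1]."""
    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0:
            return img
        # sample one alpha per principal component, then add the resulting
        # per-channel RGB offset to every pixel
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        rgb = (self.eigvec.type_as(img).clone()
               .mul(alpha.view(1, 3).expand(3, 3))
               .mul(self.eigval.view(1, 3).expand(3, 3))
               .sum(1).squeeze())
        return img.add(rgb.view(3, 1, 1).expand_as(img))

Note the transform is inserted after ToTensor and before Normalize, so it operates on raw [0, 1] pixel values.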
Example #2
def get_dataloaders(dataset,
                    batch,
                    dataroot,
                    split=0.15,
                    split_idx=0,
                    multinode=False,
                    target_lb=-1):
    if 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256

        if 'efficientnet' in C.get()['model']['type']:
            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
            sized_size = input_size + 32  # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))

        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
            ),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    elif 'gta5' in dataset:
        transform_target_after = transforms.Compose([PILToLongTensor()])

        transform_train_pre = Compose([
            RandomCrop((321, 321)),
            RandomHorizontallyFlip()])  # weak transform
        transform_valid_pre = Compose([RandomCrop((321, 321))])  # weak transform
        transform_train = transforms.Compose([transforms.ToTensor()])
        transform_test_pre = None  # Compose([RandomCrop((321, 321))])
        transform_test = transforms.Compose([transforms.ToTensor()])
    else:
        raise ValueError('dataset=%s' % dataset)

    total_aug = augs = None
    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(
                0, Augmentation(fa_reduced_cifar10()))

        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(
                0, Augmentation(fa_resnet50_rimagenet()))

        elif C.get()['aug'] == 'fa_reduced_svhn':
            transform_train.transforms.insert(0,
                                              Augmentation(fa_reduced_svhn()))

        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(
                0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0,
                                              Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default']:
            pass
        else:
            raise ValueError('unknown augmentation: %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=46000,
                                     random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets
        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)

    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot,
                                                train=False,
                                                download=True,
                                                transform=transform_test)
    elif dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root=dataroot,
                                             split='train',
                                             download=True,
                                             transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot,
                                             split='extra',
                                             download=True,
                                             transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot,
                                                   split='train',
                                                   download=True,
                                                   transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=73257 - 1000,
                                     random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train,
                                  download=True)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        #         idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [
            16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181,
            189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304,
            307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388,
            399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497,
            506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619,
            626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695,
            699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787,
            797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883,
            897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949,
            959
        ]
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=len(total_trainset) - 50000,
                                     random_state=0)  # 50000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out
        train_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(
            filter(lambda x: testset.samples[x][1] in idx120,
                   range(len(testset))))

        targets = [
            idx120.index(total_trainset.targets[idx]) for idx in train_idx
        ]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0],
                                           idx120.index(
                                               total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0],
                                    idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    elif dataset == 'gta5':
        total_trainset = GTA5_Dataset(
            data_root_path=dataroot,
            split='train',
            transform_pre=transform_train_pre,
            transform_target_after=transform_target_after,
            transform_after=transform_train,
            sample=10000)
        total_validset = GTA5_Dataset(
            data_root_path=dataroot,
            split='valid',
            transform_pre=transform_valid_pre,
            transform_target_after=transform_target_after,
            transform_after=transform_test,
            sample=5000,
            seed=0)
        testset = GTA5_Dataset(data_root_path=dataroot,
                               split='valid',
                               transform_pre=transform_test_pre,
                               transform_target_after=transform_target_after,
                               transform_after=transform_test)
    elif dataset == 'reduced_gta5':
        total_trainset = GTA5_Dataset(
            data_root_path=dataroot,
            split='train',
            transform_pre=transform_train_pre,
            transform_target_after=transform_target_after,
            transform_after=transform_train,
            sample=1000,
            seed=0)
        total_validset = GTA5_Dataset(
            data_root_path=dataroot,
            split='valid',
            transform_pre=transform_valid_pre,
            transform_target_after=transform_target_after,
            transform_after=transform_test,
            sample=1000,
            seed=0)
        testset = GTA5_Dataset(data_root_path=dataroot,
                               split='valid',
                               transform_pre=transform_test_pre,
                               transform_target_after=transform_target_after,
                               transform_after=transform_test)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

    train_sampler = None
    if split > 0.0:
        if 'gta' not in dataset:
            sss = StratifiedShuffleSplit(n_splits=5,
                                         test_size=split,
                                         random_state=0)
            sss = sss.split(list(range(len(total_trainset))),
                            total_trainset.targets)
            for _ in range(split_idx + 1):
                train_idx, valid_idx = next(sss)

            if target_lb >= 0:
                train_idx = [
                    i for i in train_idx
                    if total_trainset.targets[i] == target_lb
                ]
                valid_idx = [
                    i for i in valid_idx
                    if total_trainset.targets[i] == target_lb
                ]

            train_sampler = SubsetRandomSampler(train_idx)
            valid_sampler = SubsetSampler(valid_idx)

            if multinode:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    Subset(total_trainset, train_idx),
                    num_replicas=dist.get_world_size(),
                    rank=dist.get_rank())
        else:
            train_sampler = None
            valid_sampler = None
    else:
        valid_sampler = SubsetSampler([])
        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                total_trainset,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
            logger.info(
                f'----- dataset with DistributedSampler  {dist.get_rank()}/{dist.get_world_size()}'
            )

    trainloader = torch.utils.data.DataLoader(
        total_trainset,
        batch_size=batch,
        shuffle=True if train_sampler is None else False,
        num_workers=8,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    validloader = torch.utils.data.DataLoader(total_trainset,
                                              batch_size=batch,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=valid_sampler,
                                              drop_last=False)

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=batch,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=True,
                                             drop_last=False)
    return train_sampler, trainloader, validloader, testloader
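
CutoutDefault above is appended after Normalize, so it receives a tensor rather than a PIL image. A minimal sketch, assuming it follows the standard Cutout recipe (one randomly placed square of side `length` zeroed out); the repository's exact implementation may differ in details:

import numpy as np
import torch

class CutoutDefault(object):
    """Randomly zero out one square patch of a CxHxW image tensor."""
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        if self.length <= 0:
            return img
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        # pick a random center; the square may be clipped at the borders
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)
        mask[y1:y2, x1:x2] = 0.
        return img * torch.from_numpy(mask).expand_as(img)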
Example #3
def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, horovod=False, target_lb=-1):
    if 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
            ),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        transform_test = transforms.Compose([
            transforms.Resize(256, interpolation=Image.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        raise ValueError('dataset=%s' % dataset)

    total_aug = augs = None
    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))

        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))

        elif C.get()['aug'] == 'fa_reduced_svhn':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))

        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            pass
        else:
            raise ValueError('unknown augmentation: %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0)   # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
    elif dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=73257-1000, random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices (note: this list contains duplicate entries,
        # e.g. 630 appears twice, so fewer than 120 distinct classes are kept)
        idx120 = [904, 385, 759, 884, 784, 844, 132, 214, 990, 786, 979, 582, 104, 288, 697, 480, 66, 943, 308, 282, 118, 926, 882, 478, 133, 884, 570, 964, 825, 656, 661, 289, 385, 448, 705, 609, 955, 5, 703, 713, 695, 811, 958, 147, 6, 3, 59, 354, 315, 514, 741, 525, 685, 673, 657, 267, 575, 501, 30, 455, 905, 860, 355, 911, 24, 708, 346, 195, 660, 528, 330, 511, 439, 150, 988, 940, 236, 803, 741, 295, 111, 520, 856, 248, 203, 147, 625, 589, 708, 201, 712, 630, 630, 367, 273, 931, 960, 274, 112, 239, 463, 355, 955, 525, 404, 59, 981, 725, 90, 782, 604, 323, 418, 35, 95, 97, 193, 690, 869, 172]
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 6000, random_state=0)  # 6000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out
        train_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

    train_sampler = None
    if split > 0.0:
        sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(sss)

        if target_lb >= 0:
            train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
            valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if horovod:
            import horovod.torch as hvd
            # DistributedSampler shards a dataset, not a sampler; wrap the split subset
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                Subset(total_trainset, train_idx), num_replicas=hvd.size(), rank=hvd.rank())
    else:
        valid_sampler = SubsetSampler([])

        if horovod:
            import horovod.torch as hvd
            # no held-out split: shard the full training set across workers
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                total_trainset, num_replicas=hvd.size(), rank=hvd.rank())

    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False, num_workers=32, pin_memory=True,
        sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False, num_workers=16, pin_memory=True,
        sampler=valid_sampler, drop_last=False)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False, num_workers=32, pin_memory=True,
        drop_last=False
    )
    return train_sampler, trainloader, validloader, testloader
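
The validation loaders in these examples use SubsetSampler, which is not part of torch.utils.data. A minimal sketch, assuming it simply yields the given indices in order so that validation batches are deterministic (in contrast to the shuffled SubsetRandomSampler used for training):

from torch.utils.data import Sampler

class SubsetSampler(Sampler):
    """Samples the given indices sequentially, without shuffling."""
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)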
Example #4
def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode=False, target_lb=-1, gr_assign=None, gr_id=None, gr_ids=None, rand_val=False):
    if 'cifar' in dataset or 'svhn' in dataset:
        if "cifar" in dataset:
            _mean, _std = _CIFAR_MEAN, _CIFAR_STD
        else:
            _mean, _std = _SVHN_MEAN, _SVHN_STD
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_mean, _std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_mean, _std),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256

        if 'efficientnet' in C.get()['model']['type']:
            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
            sized_size = input_size + 32    # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))

        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
            ),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    else:
        raise ValueError('dataset=%s' % dataset)

    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
    elif isinstance(C.get()['aug'], dict):
        # group version
        logger.debug('group augmentation provided.')
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'fa_reduced_cifar10':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))

        elif C.get()['aug'] == 'fa_reduced_imagenet':
            transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))

        elif C.get()['aug'] == 'fa_reduced_svhn':
            transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))

        elif C.get()['aug'] == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', "clean", "nonorm", "nocut"]:
            pass
        else:
            raise ValueError('unknown augmentation: %s' % C.get()['aug'])

    if C.get()['cutout'] > 0 and C.get()['aug'] != "nocut":
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
    if C.get()['aug'] == "clean":
        transform_train = transform_test
    elif C.get()['aug'] == "nonorm":
        transform_train = transforms.Compose([
            transforms.ToTensor()
        ])
    train_idx = valid_idx = None
    if dataset == 'cifar10':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugCIFAR10(root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], train=True, download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=False, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'reduced_cifar10':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugCIFAR10(root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], train=True, download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=False, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=5, train_size=4000, random_state=0)   # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        for _ in range(split_idx+1):
            train_idx, valid_idx = next(sss)

        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'cifar100':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugData("CIFAR100", root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], train=True, download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=False, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=False, transform=transform_test)
    elif dataset == 'svhn': #TODO
        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=False, transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=False, transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=transform_test)
    elif dataset == 'reduced_svhn':
        if isinstance(C.get()['aug'], dict):
            total_trainset = GrAugData("SVHN", root=dataroot, gr_assign=gr_assign, gr_policies=C.get()['aug'], split='train', download=False, transform=transform_train)
        else:
            total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=False, transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=5, train_size=1000, test_size=7325, random_state=0)
        sss = sss.split(list(range(len(total_trainset))), total_trainset.labels)
        for _ in range(split_idx+1):
            train_idx, valid_idx = next(sss)
        # targets = [total_trainset.labels[idx] for idx in train_idx]
        # total_trainset = Subset(total_trainset, train_idx)
        # total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        # idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181, 189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304, 307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388, 399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497, 506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619, 626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695, 699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787, 797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883, 897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949, 959]
        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 50000, random_state=0)  # 50000 trainset
        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out
        train_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))

        targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    elif dataset == "cifar10_svhn":
        if isinstance(C.get()['aug'], dict):
            # last stage: benchmark test
            total_trainset = GrAugMix(dataset.split("_"), gr_assign=gr_assign, gr_policies=C.get()['aug'], root=dataroot, train=True, download=False, transform=transform_train, gr_ids=gr_ids)
        else:
            # eval_tta & childnet training
            total_trainset = GrAugMix(dataset.split("_"), root=dataroot, train=True, download=False, transform=transform_train)
        testset = GrAugMix(dataset.split("_"), root=dataroot, train=False, download=False, transform=transform_test)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if not hasattr(total_trainset, "gr_ids"):
        total_trainset.gr_ids = None
    if gr_ids is not None:
        total_trainset.gr_ids = gr_ids
    if gr_assign is not None and total_trainset.gr_ids is None:
        # eval_tta3
        temp_trainset = copy.deepcopy(total_trainset)
        # temp_trainset.transform = transform_test # just normalize
        temp_loader = torch.utils.data.DataLoader(
            temp_trainset, batch_size=batch, shuffle=False, num_workers=4,
            drop_last=False)
        gr_dist = gr_assign(temp_loader)
        # argmax over dim 1 gives each sample's assigned group id
        # (torch.max without a dim returns a single scalar, not (values, indices))
        gr_ids = torch.max(gr_dist, 1)[1].numpy()

    if split > 0.0:
        if train_idx is None or valid_idx is None:
            # filter by split ratio
            sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
            sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
            for _ in range(split_idx + 1):
                train_idx, valid_idx = next(sss)

        if gr_id is not None:
            # filter by group
            idx2gr = total_trainset.gr_ids
            ps = PredefinedSplit(idx2gr)
            ps = ps.split()
            for _ in range(gr_id + 1):
                _, gr_split_idx = next(ps)
            train_idx = [idx for idx in train_idx if idx in gr_split_idx]
            valid_idx = [idx for idx in valid_idx if idx in gr_split_idx]

        if target_lb >= 0:
            train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
            valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx) if not rand_val else SubsetRandomSampler(valid_idx)

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(Subset(total_trainset, train_idx), num_replicas=dist.get_world_size(), rank=dist.get_rank())
    else:
        train_sampler = None
        valid_sampler = SubsetSampler([])

        if gr_id is not None:
            # filter by group
            idx2gr = total_trainset.gr_ids
            ps = PredefinedSplit(idx2gr)
            ps = ps.split()
            for _ in range(gr_id + 1):
                _, gr_split_idx = next(ps)
            targets = [total_trainset.targets[idx] for idx in gr_split_idx]
            total_trainset = Subset(total_trainset, gr_split_idx)
            total_trainset.targets = targets

        if train_idx is not None and valid_idx is not None:
            if dataset in ["svhn", "reduced_svhn"]:
                targets = [total_trainset.labels[idx] for idx in train_idx]
            else:
                targets = [total_trainset.targets[idx] for idx in train_idx]
            total_trainset = Subset(total_trainset, train_idx)
            total_trainset.targets = targets

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
            logger.info(f'----- dataset with DistributedSampler  {dist.get_rank()}/{dist.get_world_size()}')


    trainloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False, num_workers=8 if torch.cuda.device_count()==8 else 4, pin_memory=True,
        sampler=train_sampler, drop_last=True)
    validloader = torch.utils.data.DataLoader(
        total_trainset, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True,
        sampler=valid_sampler, drop_last=False if not rand_val else True)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch, shuffle=False, num_workers=8 if torch.cuda.device_count()==8 else 4, pin_memory=True,
        drop_last=False
    )
    return train_sampler, trainloader, validloader, testloader
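
For orientation, a hypothetical call site for this last variant; the dataset name, batch size, and data root are placeholder values, and the global config C must already provide the 'aug', 'cutout', and 'model' keys the function reads:

# hold out 15% of the training set, using fold 0 of the stratified split
train_sampler, trainloader, validloader, testloader = get_dataloaders(
    dataset='cifar10',
    batch=128,
    dataroot='./data',   # CIFAR-10 must already exist here (download=False above)
    split=0.15,
    split_idx=0,
)

for images, labels in trainloader:
    pass  # training step goes here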