def __init__(self,
              root,
              config,
              crop_size,
              scale_size,
              baseline=False,
              gcn=True):
     """Load the sample index from a CSV file and build the training transform.

     Args:
         root: Dataset root directory; stored on ``self.root`` (not read here).
         config: Path to a CSV file. Each row (a pandas Series, first column
             used as the index) becomes one entry of ``self.images_list``.
         crop_size: Size passed to ``transforms.RandomCrop``.
         scale_size: Stored on ``self.scale_size``. NOTE(review): it is not
             used by the transforms built here; the ``gcn`` branch resizes to
             a hard-coded 299 instead -- confirm this is intended.
         baseline: If True, skip the ``PILColorJitter`` and PCA ``Lighting``
             augmentations (crop/flip/normalize only).
         gcn: If True, resize the random crop with ``transforms.Resize(299)``
             right after cropping.

     Reads ``info.json`` from the current working directory. It must provide
     ``mean`` and ``std`` (divided by 255 here, so presumably stored on a
     0-255 scale) and ``eigval``/``eigvec`` for the ``Lighting`` augmentation.
     """
     super(BinClsDataSet, self).__init__()
     self.root = root
     self.config = config
     self.crop_size = crop_size
     self.scale_size = scale_size
     # ``DataFrame.from_csv`` was removed in pandas 1.0; these arguments
     # reproduce its defaults (first column as index, dates parsed).
     df = pd.read_csv(config, index_col=0, parse_dates=True)
     self.images_list = [row for _, row in df.iterrows()]
     with open('info.json', 'r') as fp:
         info = json.load(fp)
     # ToTensor() yields values in [0, 1], so rescale the stored statistics.
     mean_values = torch.from_numpy(
         np.array(info['mean'], dtype=np.float32) / 255)
     std_values = torch.from_numpy(
         np.array(info['std'], dtype=np.float32) / 255)
     eigen_values = torch.from_numpy(
         np.array(info['eigval'], dtype=np.float32))
     eigen_vectors = torch.from_numpy(
         np.array(info['eigvec'], dtype=np.float32))

     # Build the pipeline step by step; the order matches the original
     # four hand-written variants exactly.
     pipeline = [transforms.RandomCrop(crop_size)]
     if gcn:
         # ``transforms.Scale`` is deprecated; ``Resize`` has the same semantics.
         pipeline.append(transforms.Resize(299))
     pipeline.append(transforms.RandomHorizontalFlip())
     if not baseline:
         pipeline.append(PILColorJitter())
     pipeline.append(transforms.ToTensor())
     if not baseline:
         # eigvec must be the eigenvector matrix; previously eigen_values was
         # passed for both arguments by mistake.
         pipeline.append(Lighting(alphastd=0.01,
                                  eigval=eigen_values,
                                  eigvec=eigen_vectors))
     pipeline.append(transforms.Normalize(mean=mean_values, std=std_values))
     self.transform = transforms.Compose(pipeline)
def get_loaders(root,
                batch_size,
                resolution,
                num_workers=32,
                val_batch_size=200,
                prefetch=False,
                color_jitter=0.4,
                pca=False,
                crop_pct=0.875):
    """Build train/val data loaders over an ``ImageFolder`` directory tree.

    Args:
        root: Directory containing ``train`` and ``val`` sub-directories.
        batch_size: Batch size of the training loader.
        resolution: Side of the random-resized training crop and of the
            evaluation center crop.
        num_workers: DataLoader worker processes for each loader; 0 loads
            in the main process (worker persistence is then disabled).
        val_batch_size: Batch size of the validation loader.
        prefetch: If True, transforms end with ``ToNumpy()``, batches are
            collated with ``fast_collate`` and both loaders are wrapped in
            ``PrefetchLoader``; ToTensor/normalization are not applied here.
        color_jitter: Strength used for brightness, contrast and saturation
            in ``ColorJitter``.
        pca: If True, add PCA ``Lighting`` noise to training images. Only
            honoured when ``prefetch`` is False.
        crop_pct: Evaluation images are resized to
            ``floor(resolution / crop_pct)`` before the center crop.

    Returns:
        ``(train_loader, val_loader)``.

    The training set is sharded with ``DistributedSampler`` when a
    ``torch.distributed`` process group is initialized; otherwise it is
    simply shuffled. The validation loader is never sharded.
    """
    normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN,
                                     std=IMAGENET_DEFAULT_STD)
    scale_size = int(math.floor(resolution / crop_pct))

    transform_train = [
        transforms.RandomResizedCrop(resolution,
                                     interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(color_jitter, color_jitter, color_jitter),
    ]
    transform_eval = [
        transforms.Resize(scale_size, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(resolution),
    ]

    if prefetch:
        # Samples stay as numpy arrays; tensor conversion/normalization is
        # presumably handled downstream by PrefetchLoader (not visible here).
        # NOTE: ``pca`` has no effect on this path.
        transform_train.append(ToNumpy())
        transform_eval.append(ToNumpy())
    else:
        transform_train.append(transforms.ToTensor())
        if pca:
            transform_train.append(
                Lighting(0.1, IMAGENET_PCA['eigval'], IMAGENET_PCA['eigvec']))
        transform_train.append(normalize)
        transform_eval += [transforms.ToTensor(), normalize]

    train_dataset = ImageFolder(root + "/train",
                                transforms.Compose(transform_train))

    # Only shard when a process group exists; DistributedSampler cannot be
    # constructed otherwise, and plain shuffling is the sensible fallback.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    val_dataset = ImageFolder(root + "/val", transforms.Compose(transform_eval))

    collate_fn = (fast_collate if prefetch
                  else torch.utils.data.dataloader.default_collate)
    # DataLoader raises ValueError for persistent_workers=True with 0 workers.
    persistent = num_workers > 0

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=num_workers,
                              pin_memory=True,
                              sampler=train_sampler,
                              collate_fn=collate_fn,
                              persistent_workers=persistent)

    val_loader = DataLoader(val_dataset,
                            batch_size=val_batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True,
                            collate_fn=collate_fn,
                            persistent_workers=persistent)
    if prefetch:
        train_loader = PrefetchLoader(train_loader)
        val_loader = PrefetchLoader(val_loader)

    return train_loader, val_loader