])
# test_preprocessor = transforms.ToTensor()
dataset = args.dataset
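# Select the training split for the requested dataset; the corresponding test-set
# loading is kept below as commented-out reference code.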
if dataset == "CIFAR-10":
    train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                            train=True,
                            transform=train_preprocessor)
    # test_dataset =  CIFAR10(IMAGE_DATA_ROOT[dataset], train=False, transform=test_preprocessor)
elif dataset == "CIFAR-100":
    train_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset],
                             train=True,
                             transform=train_preprocessor)
    # test_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset], train=False, transform=test_preprocessor)
elif dataset == "ImageNet":
    train_preprocessor = DataLoaderMaker.get_preprocessor(IMAGE_SIZE[dataset],
                                                          True,
                                                          center_crop=False)
    # test_preprocessor = DataLoaderMaker.get_preprocessor(IMAGE_SIZE[dataset], False, center_crop=True)
    train_dataset = ImageFolder(IMAGE_DATA_ROOT[dataset] + "/train",
                                transform=train_preprocessor)
elif dataset == "TinyImageNet":
    train_dataset = TinyImageNet(IMAGE_DATA_ROOT[dataset],
                                 train_preprocessor,
                                 train=True)

batch_size = args.batch_size
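# Wrap the chosen training set in a shuffled DataLoader (num_workers=0 keeps data
# loading in the main process).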
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)
    def __init__(self, tot_num_tasks, dataset, inner_batch_size, protocol):
        """
        Args:
            num_samples_per_class: num samples to generate "per class" in one batch
            batch_size: size of meta batch size (e.g. number of functions)
        """
        self.img_size = IMAGE_SIZE[dataset]
        self.dataset = dataset

        if protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II:
            self.model_names = MODELS_TRAIN_STANDARD[self.dataset]
        elif protocol == SPLIT_DATA_PROTOCOL.TRAIN_II_TEST_I:
            self.model_names = MODELS_TEST_STANDARD[self.dataset]
        elif protocol == SPLIT_DATA_PROTOCOL.TRAIN_ALL_TEST_ALL:
            self.model_names = (MODELS_TRAIN_STANDARD[self.dataset] +
                                MODELS_TEST_STANDARD[self.dataset])

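        # Instantiate a pretrained surrogate model for every selected architecture;
        # non-ImageNet models are moved to the GPU up front.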
        self.model_dict = {}
        for arch in self.model_names:
            if StandardModel.check_arch(arch, dataset):
                model = StandardModel(dataset, arch, no_grad=False).eval()
                if dataset != "ImageNet":
                    model = model.cuda()
                self.model_dict[arch] = model
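        # Load the training split of the dataset that task images will be sampled from.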
        is_train = True
        preprocessor = DataLoaderMaker.get_preprocessor(
            IMAGE_SIZE[dataset], is_train)
        if dataset == "CIFAR-10":
            train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                                    train=is_train,
                                    transform=preprocessor)
        elif dataset == "CIFAR-100":
            train_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset],
                                     train=is_train,
                                     transform=preprocessor)
        elif dataset == "MNIST":
            train_dataset = MNIST(IMAGE_DATA_ROOT[dataset],
                                  train=is_train,
                                  transform=preprocessor)
        elif dataset == "FashionMNIST":
            train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset],
                                         train=is_train,
                                         transform=preprocessor)
        elif dataset == "TinyImageNet":
            train_dataset = TinyImageNet(IMAGE_DATA_ROOT[dataset],
                                         preprocessor,
                                         train=is_train)
        elif dataset == "ImageNet":
            preprocessor = DataLoaderMaker.get_preprocessor(
                IMAGE_SIZE[dataset], is_train, center_crop=True)
            sub_folder = "/train" if is_train else "/validation"  # Note that ImageNet uses pretrainedmodels.utils.TransformImage to apply transformation
            train_dataset = ImageFolder(IMAGE_DATA_ROOT[dataset] + sub_folder,
                                        transform=preprocessor)
        self.train_dataset = train_dataset
        self.total_num_images = len(train_dataset)
        self.all_tasks = dict()
        all_images_indexes = np.arange(self.total_num_images).tolist()
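        # Pre-sample every meta-task: each task stores a random set of image indices
        # and one randomly chosen surrogate architecture.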
        for i in range(tot_num_tasks):
            self.all_tasks[i] = {
                "image": random.sample(all_images_indexes, inner_batch_size),
                "arch": random.choice(list(self.model_dict.keys()))
            }