def cifar10(dataset_root, split: str = "train"):
    """Build a ``CIFAR10Mixed`` dataset using the AMDIM CIFAR-10 training transforms.

    Args:
        dataset_root: Root directory forwarded to ``CIFAR10Mixed`` as ``root``
            (``download=True`` is also passed through).
        split: Which split to build; must be ``"train"`` or ``"val"``.

    Returns:
        The constructed ``CIFAR10Mixed`` dataset.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"val"``.
    """
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
    if split not in ("train", "val"):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # NOTE(review): the *training* transform is used for the "val" split too —
    # confirm this is intended (e.g. for contrastive pretraining) rather than an
    # evaluation transform being the right choice.
    return CIFAR10Mixed(
        root=dataset_root,
        split=split,
        transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
        download=True,
    )
def cifar10_tiny(dataset_root, split: str = "train"):
    """Build a reduced-label ``CIFAR10Mixed`` dataset with AMDIM CIFAR-10 training transforms.

    Identical to the full CIFAR-10 builder except that ``nb_labeled_per_class=50``
    is passed to ``CIFAR10Mixed`` (presumably keeping only 50 labeled examples
    per class — confirm against ``CIFAR10Mixed``).

    Args:
        dataset_root: Root directory forwarded to ``CIFAR10Mixed`` as ``root``
            (``download=True`` is also passed through).
        split: Which split to build; must be ``"train"`` or ``"val"``.

    Returns:
        The constructed ``CIFAR10Mixed`` dataset.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"val"``.
    """
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
    if split not in ("train", "val"):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # NOTE(review): the *training* transform is used for the "val" split too —
    # confirm this is intended.
    return CIFAR10Mixed(
        root=dataset_root,
        split=split,
        transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
        download=True,
        nb_labeled_per_class=50,
    )