def cifar10(dataset_root, split: str = 'train'):
    """Build a ``CIFAR10Mixed`` dataset for the requested split.

    Args:
        dataset_root: Forwarded to ``CIFAR10Mixed`` as ``root``.
        split: Which split to build; must be ``'train'`` or ``'val'``.

    Returns:
        The constructed ``CIFAR10Mixed`` instance.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would let an invalid split reach CIFAR10Mixed.
    if split not in ('train', 'val'):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # NOTE(review): the train transforms are applied to the 'val' split as
    # well -- presumably intentional for this setup; confirm against callers.
    return CIFAR10Mixed(
        root=dataset_root,
        split=split,
        transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
        download=True,
    )
示例#2
0
 def cifar10_tiny(dataset_root, split: str = "train"):
     """Build a ``CIFAR10Mixed`` dataset with ``nb_labeled_per_class=50``.

     Identical to the plain CIFAR-10 factory except that it additionally
     passes ``nb_labeled_per_class=50`` to ``CIFAR10Mixed``.

     Args:
         dataset_root: Forwarded to ``CIFAR10Mixed`` as ``root``.
         split: Which split to build; must be ``"train"`` or ``"val"``.

     Returns:
         The constructed ``CIFAR10Mixed`` instance.

     Raises:
         ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
     """
     # Raise explicitly rather than `assert`: assertions are stripped under
     # `python -O`, which would let an invalid split reach CIFAR10Mixed.
     if split not in ("train", "val"):
         raise ValueError(f"split must be 'train' or 'val', got {split!r}")

     # NOTE(review): the train transforms are applied to the "val" split as
     # well -- presumably intentional for this setup; confirm against callers.
     return CIFAR10Mixed(
         root=dataset_root,
         split=split,
         transform=amdim_transforms.AMDIMTrainTransformsCIFAR10(),
         download=True,
         # Presumably caps the number of labeled examples kept per class;
         # verify against the CIFAR10Mixed implementation.
         nb_labeled_per_class=50,
     )