def omniglot(folder, shots, ways, shuffle=True, test_shots=None, seed=None,
             **kwargs):
    """Helper function to create a meta-dataset for the Omniglot dataset.

    Parameters
    ----------
    folder : string
        Root directory where the dataset folder `omniglot` exists.
    shots : int
        Number of (training) examples per class in each task. This corresponds
        to `k` in `k-shot` classification.
    ways : int
        Number of classes per task. This corresponds to `N` in `N-way`
        classification.
    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.
    test_shots : int, optional
        Number of test examples per class in each task. If `None`, then the
        number of test examples is equal to the number of training examples
        per class.
    seed : int, optional
        Random seed to be used in the meta-dataset.
    kwargs
        Additional arguments passed to the `Omniglot` class.

    See also
    --------
    `datasets.Omniglot` : Meta-dataset for the Omniglot dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        # BUG FIX: pop the key instead of a plain lookup. Leaving
        # `num_classes_per_task` inside `kwargs` made the `Omniglot(...)` call
        # below raise `TypeError: got multiple values for keyword argument
        # 'num_classes_per_task'` -- the exact case this branch warns about.
        ways = kwargs.pop('num_classes_per_task')
    # Fill in the torchmeta defaults only when the caller did not override
    # them explicitly.
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(28), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        # Rotations of each character class count as additional classes.
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]
    if test_shots is None:
        # Default: as many test examples per class as training examples.
        test_shots = shots

    dataset = Omniglot(folder, num_classes_per_task=ways, **kwargs)
    # Split every class's examples into `shots` train / `test_shots` test.
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
def create_og_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    input_size,
    n_query,
    batch_size,
    num_workers,
    download=False,
    use_vinyals_split=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for Omniglot

    Args:
        root: Path to Omniglot data root folder (containing an 'omniglot'`
            subfolder with the preprocess json-Files or downloaded zip-files).
        meta_split: see torchmeta.datasets.Omniglot
        k_way: Number of classes per task
        n_shot: Number of samples per class
        input_size: Images are resized to this size.
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Download (and dataset specific preprocessing that needs to
            be done on the downloaded files).
        use_vinyals_split: see torchmeta.datasets.Omniglot
        seed: Seed to be used in the meta-dataset

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = Omniglot(
        root,
        num_classes_per_task=k_way,
        transform=Compose([Resize(input_size), ToTensor()]),
        target_transform=Categorical(num_classes=k_way),
        class_augmentations=[Rotation([90, 180, 270])],
        meta_split=meta_split,
        download=download,
        use_vinyals_split=use_vinyals_split,
    )
    # Split every class's examples into `n_shot` train / `n_query` test.
    dataset = ClassSplitter(
        dataset,
        shuffle=True,
        num_train_per_class=n_shot,
        num_test_per_class=n_query,
    )
    # BUG FIX: the original did `dataset.seed = seed`, which shadows the
    # `ClassSplitter.seed` *method* with an int and never seeds the task
    # sampler, so `seed` silently had no effect. Call the method instead
    # (consistent with the `omniglot` helper above); it accepts `None`.
    dataset.seed(seed)
    dataloader = BatchMetaDataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    return dataloader