示例#1
0
def omniglot(folder, shots, ways, shuffle=True, test_shots=None,
             seed=None, **kwargs):
    """Helper function to create a meta-dataset for the Omniglot dataset.

    Builds an `Omniglot` dataset, fills in defaults for any of `transform`,
    `target_transform` and `class_augmentations` that were not supplied in
    `kwargs`, wraps the dataset with `ClassSplitter` to obtain a train/test
    split for each task, and finally seeds it.

    Parameters
    ----------
    folder : string
        Root directory where the dataset folder `omniglot` exists.

    shots : int
        Number of (training) examples per class in each task. This corresponds
        to `k` in `k-shot` classification.

    ways : int
        Number of classes per task. This corresponds to `N` in `N-way`
        classification. Ignored (with a warning) if `num_classes_per_task` is
        also given in `kwargs`.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples per class in each task. If `None`, then the
        number of test examples is equal to the number of training examples per
        class.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `Omniglot` class. The following
        defaults are applied when the corresponding key is absent:
        `transform=Compose([Resize(28), ToTensor()])`,
        `target_transform=Categorical(ways)` and
        `class_augmentations=[Rotation([90, 180, 270])]`.

    Returns
    -------
    dataset
        The object returned by `ClassSplitter` applied to the `Omniglot`
        dataset, after `dataset.seed(seed)` has been called on it.

    See also
    --------
    `datasets.Omniglot` : Meta-dataset for the Omniglot dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        # Pop (rather than read) the key: it is forwarded explicitly below,
        # and leaving it in `kwargs` would pass the same keyword twice and
        # raise a TypeError.
        ways = kwargs.pop('num_classes_per_task')

    # Only fill in defaults for options the caller did not provide.
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(28), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]
    if test_shots is None:
        test_shots = shots

    dataset = Omniglot(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
示例#2
0
def create_og_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    input_size,
    n_query,
    batch_size,
    num_workers,
    download=False,
    use_vinyals_split=False,
    seed=None,
):
    """Create a torchmeta BatchMetaDataLoader for Omniglot.

    The dataset is built with images resized to `input_size`, categorical
    targets over `k_way` classes and rotations of 90, 180 and 270 degrees as
    class augmentations. It is then split per task into `n_shot` training and
    `n_query` test examples per class, seeded, and wrapped in a shuffling
    `BatchMetaDataLoader`.

    Args:
        root: Path to Omniglot data root folder (containing an `omniglot` subfolder with the
            preprocessed JSON files or downloaded zip files).
        meta_split: see torchmeta.datasets.Omniglot
        k_way: Number of classes per task
        n_shot: Number of samples per class
        input_size: Images are resized to this size.
        n_query: Number of test images per class
        batch_size: Meta batch size
        num_workers: Number of workers for data preprocessing
        download: Download (and dataset specific preprocessing that needs to be done on the
            downloaded files).
        use_vinyals_split: see torchmeta.datasets.Omniglot
        seed: Seed to be used in the meta-dataset

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    dataset = Omniglot(
        root,
        num_classes_per_task=k_way,
        transform=Compose([Resize(input_size), ToTensor()]),
        target_transform=Categorical(num_classes=k_way),
        class_augmentations=[Rotation([90, 180, 270])],
        meta_split=meta_split,
        download=download,
        use_vinyals_split=use_vinyals_split,
    )
    dataset = ClassSplitter(dataset,
                            shuffle=True,
                            num_train_per_class=n_shot,
                            num_test_per_class=n_query)
    # `seed` is a method on the dataset: call it. Assigning to the attribute
    # would shadow the method and leave the dataset unseeded.
    dataset.seed(seed)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     shuffle=True)
    return dataloader