def test_seed_class_splitter():
    """Check that a seeded dataset reproduces a known train/test split.

    A `Sinusoid` dataset is built with a shuffling `ClassSplitter`
    (5 train / 5 test examples) as its dataset transform and seeded with 1.
    Every example of the first task is then compared against hard-coded
    reference values.
    """
    splitter = ClassSplitter(
        shuffle=True,
        num_train_per_class=5,
        num_test_per_class=5,
    )
    dataset = Sinusoid(
        10,
        num_tasks=1000,
        noise_std=0.1,
        dataset_transform=splitter,
    )
    dataset.seed(1)

    # Reference (inputs, targets) for each split of task 0 under seed 1.
    reference = {
        'train': (
            np.array([1.08565437, -1.56211897, 4.62078213,
                      -2.03870077, 0.76977846]),
            np.array([-0.00309463, -1.37650356, -0.9346262,
                      -0.1031986, -0.4698061]),
        ),
        'test': (
            np.array([-2.48340416, 3.75388738, -3.15504396,
                      0.09898378, 0.32922559]),
            np.array([0.73113509, 0.91773121, 1.86656819,
                      -1.61885041, -1.52508997]),
        ),
    }
    split_names = ('train', 'test')

    first_task = dataset[0]
    splits = {name: first_task[name] for name in split_names}

    # Both splits must have the requested size before values are compared.
    for name in split_names:
        assert len(splits[name]) == 5

    for name in split_names:
        ref_inputs, ref_targets = reference[name]
        for position, sample in enumerate(splits[name]):
            sample_input, sample_target = sample
            assert np.isclose(sample_input, ref_inputs[position])
            assert np.isclose(sample_target, ref_targets[position])
示例#2
0
文件: helpers.py 项目: struemya/INR
def sinusoid(shots, shuffle=True, test_shots=None, seed=None, **kwargs):
    """Helper function to create a meta-dataset for the Sinusoid toy dataset.

    Parameters
    ----------
    shots : int
        Number of (training) examples in each task. This corresponds to `k` in
        `k-shot` classification.

    shuffle : bool (default: `True`)
        Shuffle the examples when creating the tasks.

    test_shots : int, optional
        Number of test examples in each task. If `None`, then the number of test
        examples is equal to the number of training examples in each task.

    seed : int, optional
        Random seed to be used in the meta-dataset.

    kwargs
        Additional arguments passed to the `Sinusoid` class. If
        `num_samples_per_task` is given here, a warning is issued and `shots`
        is ignored: the number of training examples becomes
        `num_samples_per_task - test_shots`, or `num_samples_per_task // 2`
        when `test_shots` is `None`.

    Returns
    -------
    The `Sinusoid` dataset wrapped with `ClassSplitter`, after `seed(seed)`
    has been called on it.

    Raises
    ------
    ValueError
        If `num_samples_per_task` is given and
        `num_samples_per_task - test_shots` is not positive.

    See also
    --------
    `torchmeta.toy.Sinusoid` : Meta-dataset for the Sinusoid toy dataset.
    """
    if 'num_samples_per_task' in kwargs:
        # Take the value out of `kwargs`: `num_samples_per_task` is passed
        # explicitly to `Sinusoid` below, so leaving it in `kwargs` would
        # supply the same keyword twice and raise a `TypeError`.
        num_samples_per_task = kwargs.pop('num_samples_per_task')
        warnings.warn(
            'Both arguments `shots` and `num_samples_per_task` were '
            'set in the helper function for the number of samples in each task. '
            'Ignoring the argument `shots`.',
            stacklevel=2)
        if test_shots is not None:
            shots = num_samples_per_task - test_shots
            if shots <= 0:
                raise ValueError(
                    'The argument `test_shots` ({0}) is greater '
                    'than the number of samples per task ({1}). Either use the '
                    'argument `shots` instead of `num_samples_per_task`, or '
                    'increase the value of `num_samples_per_task`.'.format(
                        test_shots, num_samples_per_task))
        else:
            # Even split; for an odd `num_samples_per_task` the total below
            # is `2 * (num_samples_per_task // 2)`, i.e. one sample fewer.
            shots = num_samples_per_task // 2
    if test_shots is None:
        test_shots = shots

    dataset = Sinusoid(num_samples_per_task=shots + test_shots, **kwargs)
    dataset = ClassSplitter(dataset,
                            shuffle=shuffle,
                            num_train_per_class=shots,
                            num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset