コード例 #1
0
def test_misconfiguration():
    """Check that ``SplitDataset`` rejects invalid ``indices`` arguments.

    Every case constructs a ``SplitDataset`` and expects a
    ``MisconfigurationException`` whose message matches the given pattern.
    """

    def expect_rejection(pattern, dataset, **split_kwargs):
        # Build the split inside ``pytest.raises`` so a missing error fails the test.
        with pytest.raises(MisconfigurationException, match=pattern):
            SplitDataset(dataset, **split_kwargs)

    # NOTE(review): pytest applies ``match`` with ``re.search``, so the
    # unescaped brackets below are a regex character class, not the literal
    # text "[0, 99]". The patterns are therefore looser than they look —
    # confirm the intended message and consider ``re.escape``.

    # Index past the end of the dataset.
    expect_rejection("[0, 99]", range(100), indices=[100])

    # Negative index.
    expect_rejection("[0, 49]", range(50), indices=[-1])

    # Negative index on a dataset built from two concatenated ranges.
    expect_rejection("[0, 49]", list(range(50)) + list(range(50)), indices=[-1])

    # Same dataset, with duplicated indices explicitly allowed.
    expect_rejection(
        "[0, 99]",
        list(range(50)) + list(range(50)),
        indices=[-1],
        use_duplicated_indices=True,
    )

    # ``indices`` that is not a list at all.
    expect_rejection("indices should be a list", list(range(100)), indices="not a list")
コード例 #2
0
    def _split_train_val(
        train_dataset: Dataset,
        val_split: float,
    ) -> Tuple[Any, Any]:
        """Split a training dataset into disjoint training and validation subsets.

        The sample indices ``0 .. len(train_dataset) - 1`` are shuffled in place with
        ``np.random.shuffle``; the first ``int(len(train_dataset) * val_split)`` shuffled
        indices go to the validation subset and the remaining ones to the training subset.

        Args:
            train_dataset: An instance of a :class:`torch.utils.data.Dataset`.
            val_split: A float between 0 and 1 (both inclusive) giving the fraction of samples
                that should go into the validation split.

        Returns:
            A tuple ``(train, val)`` of ``SplitDataset`` objects wrapping ``train_dataset``.

        Raises:
            MisconfigurationException: If ``val_split`` is not a float within ``[0, 1]``
                (NaN included), or if ``train_dataset`` is an ``IterableInput``.
        """

        # A chained comparison is False for NaN, so NaN is rejected here with a clear message
        # instead of failing later inside ``int(...)`` with an unrelated ``ValueError``.
        if not isinstance(val_split, float) or not 0.0 <= val_split <= 1.0:
            raise MisconfigurationException(
                f"`val_split` should be a float between 0 and 1. Found {val_split}."
            )

        if isinstance(train_dataset, IterableInput):
            raise MisconfigurationException(
                "`val_split` should be `None` when the dataset is built with an IterableDataset."
            )

        num_samples = len(train_dataset)
        val_num_samples = int(num_samples * val_split)

        # Shuffle a plain Python list so both subsets draw from one random permutation.
        indices = list(range(num_samples))
        np.random.shuffle(indices)
        val_indices = indices[:val_num_samples]
        train_indices = indices[val_num_samples:]

        train_subset = SplitDataset(
            train_dataset,
            train_indices,
            running_stage=RunningStage.TRAINING,
            use_duplicated_indices=True,
        )
        val_subset = SplitDataset(
            train_dataset,
            val_indices,
            running_stage=RunningStage.VALIDATING,
            use_duplicated_indices=True,
        )
        return train_subset, val_subset
コード例 #3
0
def test_split_dataset():
    """Exercise ``DataModule._split_train_val`` and attribute access on ``SplitDataset``."""
    train_part, val_part = DataModule._split_train_val(range(100), val_split=0.1)

    # A 0.1 split of 100 samples is expected to give 90 training / 10 validation items.
    assert len(train_part) == 90
    assert len(val_part) == 10

    # The training indices must not contain repeats.
    train_indices = train_part.indices
    assert len(np.unique(train_indices)) == len(train_indices)

    class Dataset:
        """Three-element dataset carrying a couple of extra attributes."""

        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"
            self.is_passed_down = False

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    wrapper = SplitDataset(Dataset(), indices=[0])

    # Attributes of the wrapped dataset are expected to be readable via the wrapper.
    assert wrapper.name == "something"

    # Setting an attribute on the wrapper is expected to leave the wrapped dataset untouched.
    wrapper.is_passed_down = True
    wrapped = wrapper.dataset
    assert not wrapped.is_passed_down
コード例 #4
0
def test_split_dataset(tmpdir):
    """Exercise ``_split_train_val``, ``SplitDataset`` validation and attribute forwarding."""

    def expect_rejection(pattern, dataset, **split_kwargs):
        # Build the split inside ``pytest.raises`` so a missing error fails the test.
        with pytest.raises(MisconfigurationException, match=pattern):
            SplitDataset(dataset, **split_kwargs)

    train_part, val_part = DataModule._split_train_val(range(100), val_split=0.1)

    # A 0.1 split of 100 samples is expected to give 90 training / 10 validation items.
    assert len(train_part) == 90
    assert len(val_part) == 10

    # The training indices must not contain repeats.
    train_indices = train_part.indices
    assert len(np.unique(train_indices)) == len(train_indices)

    # NOTE(review): pytest applies ``match`` with ``re.search``, so the
    # unescaped brackets below are a regex character class, not the literal
    # text "[0, 99]". The patterns are therefore looser than they look —
    # confirm the intended message and consider ``re.escape``.
    expect_rejection("[0, 99]", range(100), indices=[100])
    expect_rejection("[0, 49]", range(50), indices=[-1])
    expect_rejection("[0, 49]", list(range(50)) + list(range(50)), indices=[-1])
    expect_rejection(
        "[0, 99]",
        list(range(50)) + list(range(50)),
        indices=[-1],
        use_duplicated_indices=True,
    )

    class Dataset:
        """Three-element dataset carrying an extra ``name`` attribute."""

        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    wrapper = SplitDataset(Dataset(), indices=[0])

    # Attributes of the wrapped dataset are expected to be readable via the wrapper.
    assert wrapper.name == "something"

    assert wrapper._INTERNAL_KEYS == ("dataset", "indices", "data")

    # An attribute set on the wrapper is expected to land on the wrapped dataset.
    wrapper.is_passed_down = True
    wrapped = wrapper.dataset
    assert wrapped.is_passed_down