# Imports assumed for the snippets below. The Flash-specific paths follow
# lightning-flash's layout and may differ between versions.
from typing import Any, Tuple

import numpy as np
import pytest
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset

from flash import DataModule
from flash.core.data.io.input import IterableInput
from flash.core.data.splits import SplitDataset


def test_misconfiguration():
    with pytest.raises(MisconfigurationException, match="[0, 99]"):
        SplitDataset(range(100), indices=[100])

    with pytest.raises(MisconfigurationException, match="[0, 49]"):
        SplitDataset(range(50), indices=[-1])

    with pytest.raises(MisconfigurationException, match="[0, 49]"):
        SplitDataset(list(range(50)) + list(range(50)), indices=[-1])

    with pytest.raises(MisconfigurationException, match="[0, 99]"):
        SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True)

    with pytest.raises(MisconfigurationException, match="indices should be a list"):
        SplitDataset(list(range(100)), indices="not a list")
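# For reference, a minimal sketch of the index validation the test above
# exercises. This is an assumption about SplitDataset's internals, not the
# actual implementation, and `_validate_indices` is a hypothetical name.
# Note that pytest's `match` argument is a regex, so a pattern like "[0, 99]"
# is a character class and matches the error messages only loosely.
def _validate_indices(dataset, indices, use_duplicated_indices=False):
    if not isinstance(indices, list):
        raise MisconfigurationException("indices should be a list")
    if not use_duplicated_indices:
        # np.unique drops duplicates (and sorts as a side effect)
        indices = list(np.unique(indices))
    if np.max(indices) >= len(dataset) or np.min(indices) < 0:
        raise MisconfigurationException(f"`indices` should be within [0, {len(dataset) - 1}].")
    return indices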
def _split_train_val(
    train_dataset: Dataset,
    val_split: float,
) -> Tuple[Any, Any]:
    """Utility function for splitting the training dataset into a disjoint subset of training samples and
    validation samples.

    Args:
        train_dataset: An instance of a :class:`torch.utils.data.Dataset`.
        val_split: A float between 0 and 1 determining the fraction of samples that should go into the
            validation split.

    Returns:
        A tuple containing the training and validation datasets.
    """
    if not isinstance(val_split, float) or not (0 <= val_split <= 1):
        raise MisconfigurationException(f"`val_split` should be a float between 0 and 1. Found {val_split}.")

    if isinstance(train_dataset, IterableInput):
        raise MisconfigurationException(
            "`val_split` should be `None` when the dataset is built with an IterableDataset."
        )

    val_num_samples = int(len(train_dataset) * val_split)
    indices = list(range(len(train_dataset)))
    np.random.shuffle(indices)
    val_indices = indices[:val_num_samples]
    train_indices = indices[val_num_samples:]
    # The shuffled permutation is already duplicate-free, so
    # `use_duplicated_indices=True` keeps it as-is instead of deduplicating.
    return (
        SplitDataset(train_dataset, train_indices, running_stage=RunningStage.TRAINING, use_duplicated_indices=True),
        SplitDataset(train_dataset, val_indices, running_stage=RunningStage.VALIDATING, use_duplicated_indices=True),
    )
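# Usage sketch: `_split_train_val` shuffles with NumPy's global RNG, so
# seeding it first makes the split reproducible. The values are illustrative.
def example_seeded_split():
    np.random.seed(42)
    train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.2)
    assert len(train_ds) == 80
    assert len(val_ds) == 20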
def test_split_dataset():
    train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1)
    assert len(train_ds) == 90
    assert len(val_ds) == 10
    assert len(np.unique(train_ds.indices)) == len(train_ds.indices)

    class Dataset:
        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"
            self.is_passed_down = False

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

    split_dataset = SplitDataset(Dataset(), indices=[0])
    assert split_dataset.name == "something"
    assert split_dataset._INTERNAL_KEYS == ("dataset", "indices", "data")

    split_dataset.is_passed_down = True
    assert not split_dataset.dataset.is_passed_down
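# A minimal sketch of the attribute behavior the test above checks: reads of
# unknown attributes are delegated to the wrapped dataset, while writes stay
# on the wrapper. Hypothetical -- the real SplitDataset may differ.
class _AttributeProxySketch:
    _INTERNAL_KEYS = ("dataset", "indices", "data")

    def __init__(self, dataset):
        self.dataset = dataset

    def __getattr__(self, key):
        # Only called when normal lookup fails, i.e. for non-internal keys;
        # those are forwarded to the wrapped dataset.
        if key not in self._INTERNAL_KEYS:
            return getattr(self.dataset, key)
        raise AttributeError(key)

    # __setattr__ is intentionally not overridden: assignments land on the
    # wrapper's __dict__ and never reach the wrapped dataset.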