def test_split_weighted(self):
    """Check that a weighted subsplit of a merged split slices every component alike."""

    def expected(slice_value):
        # Every component of the merged split receives the same percent slice.
        return [
            splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(name="test", num_shards=2),
                slice_value=slice_value,
            ),
            splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(name="train", num_shards=10),
                slice_value=slice_value,
            ),
        ]

    merged = tfds.Split.TEST + tfds.Split.TRAIN
    # Weights 2:1 -> first two thirds, then the remaining third (in percent).
    first_part, second_part = merged.subsplit(weighted=[2, 1])
    self.assertEqual(self._info(first_part), expected(slice(0, 66)))
    self.assertEqual(self._info(second_part), expected(slice(66, 100)))
def test_split_slice_merge(self):
    """Check that slicing one split and then merging keeps the slice on that split only."""
    sliced_test = tfds.Split.TEST.subsplit(tfds.percent[30:40])
    merged = sliced_test + tfds.Split.TRAIN
    actual = self._info(merged)

    sliced_test_info = splits.SlicedSplitInfo(
        split_info=tfds.core.SplitInfo(name="test", num_shards=2),
        slice_value=slice(30, 40),
    )
    # The un-sliced train split carries no slice at all.
    whole_train_info = splits.SlicedSplitInfo(
        split_info=tfds.core.SplitInfo(name="train", num_shards=10),
        slice_value=None,
    )
    # The info list is sorted by split name, so the order is deterministic.
    self.assertEqual(actual, [sliced_test_info, whole_train_info])
def test_split_merge_slice(self):
    """Check merge -> slice -> merge: the slice applies to every merged component."""
    merged = tfds.Split.TEST + tfds.Split.TRAIN
    merged = merged.subsplit(tfds.percent[30:40])
    custom_head = tfds.Split("custom").subsplit(tfds.percent[:15])
    actual = self._info(merged + custom_head)

    def sliced(name, num_shards, slice_value):
        # Shorthand for one expected SlicedSplitInfo entry.
        return splits.SlicedSplitInfo(
            split_info=tfds.core.SplitInfo(name=name, num_shards=num_shards),
            slice_value=slice_value,
        )

    # The info list is sorted by split name, so the order is deterministic.
    self.assertEqual(actual, [
        sliced("custom", 2, slice(None, 15)),
        sliced("test", 2, slice(30, 40)),
        sliced("train", 10, slice(30, 40)),
    ])
def test_split_k(self):
    """Check that splitting a merged split into k=3 parts gives matching percent ranges."""
    merged = tfds.Split.TEST + tfds.Split.TRAIN
    # Unpack explicitly so the test also pins that exactly three parts come back.
    first_part, second_part, third_part = merged.subsplit(k=3)

    cases = (
        (first_part, 0, 33),
        (second_part, 33, 66),
        (third_part, 66, 100),
    )
    for part, start, stop in cases:
        actual = self._info(part)
        # Both merged components share the same [start, stop) percent range.
        self.assertEqual(actual, [
            splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(name="test", num_shards=2),
                slice_value=slice(start, stop),
            ),
            splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(name="train", num_shards=10),
                slice_value=slice(start, stop),
            ),
        ])