예제 #1
0
    def test_split_weighted(self):
        """Check a weighted 2:1 subsplit of the merged TEST + TRAIN split.

        The first part is expected to resolve to slice(0, 66) and the second
        to slice(66, 100), each applied to both the "test" and "train" splits.
        """
        merged = tfds.Split.TEST + tfds.Split.TRAIN
        larger_part, smaller_part = merged.subsplit(weighted=[2, 1])

        cases = (
            (larger_part, (0, 66)),
            (smaller_part, (66, 100)),
        )
        for part, (start, stop) in cases:
            actual = self._info(part)
            # The same slice bounds are expected for both source splits.
            expected = [
                splits.SlicedSplitInfo(
                    split_info=tfds.core.SplitInfo(
                        name=name, num_shards=num_shards),
                    slice_value=slice(start, stop),
                )
                for name, num_shards in (("test", 2), ("train", 10))
            ]
            self.assertEqual(actual, expected)
예제 #2
0
    def test_split_slice_merge(self):
        """Check slicing TEST by percent and then merging it with TRAIN.

        Only the "test" entry is expected to carry a slice (30 to 40); the
        "train" entry is expected to have no slice at all (None).
        """
        # Slice first, merge afterwards.
        whole_train = tfds.Split.TRAIN
        sliced_test = tfds.Split.TEST.subsplit(tfds.percent[30:40])
        combined = sliced_test + whole_train

        actual = self._info(combined)

        # Expected entries are listed in name order ("test" before "train").
        expected_test = splits.SlicedSplitInfo(
            split_info=tfds.core.SplitInfo(name="test", num_shards=2),
            slice_value=slice(30, 40),
        )
        expected_train = splits.SlicedSplitInfo(
            split_info=tfds.core.SplitInfo(name="train", num_shards=10),
            slice_value=None,
        )
        self.assertEqual(actual, [expected_test, expected_train])
예제 #3
0
    def test_split_merge_slice(self):
        """Check merging TEST + TRAIN, slicing the result, then merging again.

        The 30-40 percent slice is expected to apply to both "test" and
        "train", while the separately sliced "custom" split keeps its own
        slice(None, 15).
        """
        # Merge, then slice, then merge in another sliced split.
        merged = tfds.Split.TEST + tfds.Split.TRAIN
        merged_slice = merged.subsplit(tfds.percent[30:40])
        custom_slice = tfds.Split("custom").subsplit(tfds.percent[:15])
        combined = merged_slice + custom_slice

        actual = self._info(combined)

        # Expected entries are listed in name order: custom, test, train.
        expected_rows = (
            ("custom", 2, slice(None, 15)),
            ("test", 2, slice(30, 40)),
            ("train", 10, slice(30, 40)),
        )
        expected = [
            splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(
                    name=name, num_shards=num_shards),
                slice_value=slice_value,
            )
            for name, num_shards, slice_value in expected_rows
        ]
        self.assertEqual(actual, expected)
예제 #4
0
    def test_split_k(self):
        """Check that subsplit(k=3) cuts TEST + TRAIN into three parts.

        The parts are expected to resolve to slice(0, 33), slice(33, 66) and
        slice(66, 100), each applied to both the "test" and "train" splits.
        """
        merged = tfds.Split.TEST + tfds.Split.TRAIN
        # Strict unpacking: exactly three sub-splits must come back.
        first, second, third = merged.subsplit(k=3)

        def check(part, start, stop):
            # Resolve the actual info before building the expectation.
            actual = self._info(part)
            test_entry = splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(name="test", num_shards=2),
                slice_value=slice(start, stop),
            )
            train_entry = splits.SlicedSplitInfo(
                split_info=tfds.core.SplitInfo(name="train", num_shards=10),
                slice_value=slice(start, stop),
            )
            self.assertEqual(actual, [test_entry, train_entry])

        check(first, 0, 33)
        check(second, 33, 66)
        check(third, 66, 100)