def test_random_split_gives_error_on_wrong_ratios(self):
    """RandomSplit must refuse invalid split-ratio configurations.

    A single-item ``Dataset`` is fed to ``transforms.RandomSplit`` with
    three bad ``splits`` arguments, and each call is expected to raise:
    ratios adding up to more than 1 (0.5 + 0.7), an empty split list,
    and a list containing a negative ratio (-0.5 next to 1.5).

    NOTE(review): a method with this exact name appears next to this one
    in the surrounding source; inside a single class the later definition
    replaces the earlier one, so only one of them is collected by the
    test runner -- confirm whether both are meant to exist.
    """
    dataset = Dataset.from_iterable([DatasetItem(id=1)])

    invalid_split_configs = (
        [('train', 0.5), ('test', 0.7)],   # ratios sum past 1
        [],                                # no subsets requested
        [('train', -0.5), ('test', 1.5)],  # negative ratio
    )
    for bad_splits in invalid_split_configs:
        with self.assertRaises(Exception):
            transforms.RandomSplit(dataset, splits=bad_splits)
def test_random_split_gives_error_on_wrong_ratios(self):
    """RandomSplit must refuse invalid split-ratio configurations.

    A fresh one-item extractor is passed to ``transforms.RandomSplit``
    for each of three bad ``splits`` arguments, and every call is
    expected to raise: ratios adding up to more than 1 (0.5 + 0.7), an
    empty split list, and a list with a negative ratio (-0.5 next to 1.5).

    NOTE(review): a method with this exact name appears next to this one
    in the surrounding source; inside a single class the later definition
    replaces the earlier one, so only one of them is collected by the
    test runner -- confirm whether both are meant to exist.
    """
    class SrcExtractor(Extractor):
        def __iter__(self):
            return iter([DatasetItem(id=1)])

    invalid_split_configs = (
        [('train', 0.5), ('test', 0.7)],   # ratios sum past 1
        [],                                # no subsets requested
        [('train', -0.5), ('test', 1.5)],  # negative ratio
    )
    for bad_splits in invalid_split_configs:
        # A new extractor instance per call, as in the original test.
        with self.assertRaises(Exception):
            transforms.RandomSplit(SrcExtractor(), splits=bad_splits)
def test_random_split(self):
    """RandomSplit distributes every item into the requested subsets.

    Seven items with ids 1..7, originally in subsets "a", "b" and "",
    are split with ratios 4/7 ('train') and 3/7 ('test'). The resulting
    'train' subset must hold exactly 4 items and 'test' exactly 3. Those
    counts cover all 7 inputs, whatever subset each item started in.

    NOTE(review): a method with this exact name appears next to this one
    in the surrounding source; inside a single class the later definition
    replaces the earlier one, so only one of them is collected by the
    test runner -- confirm whether both are meant to exist.
    """
    original_subsets = ["a", "a", "b", "b", "b", "", ""]
    source_dataset = Dataset.from_iterable([
        DatasetItem(id=item_id, subset=subset)
        for item_id, subset in enumerate(original_subsets, start=1)
    ])

    split_ratios = [
        ('train', 4.0 / 7.0),
        ('test', 3.0 / 7.0),
    ]
    actual = transforms.RandomSplit(source_dataset, splits=split_ratios)

    expected_sizes = {'train': 4, 'test': 3}
    for subset_name, expected_size in expected_sizes.items():
        self.assertEqual(expected_size,
            len(actual.get_subset(subset_name)))
def test_random_split(self):
    """RandomSplit distributes every extractor item into the requested subsets.

    A local extractor yields seven items with ids 1..7, originally in
    subsets "a", "b" and "". They are split with ratios 4/7 ('train') and
    3/7 ('test'). The resulting 'train' subset must hold exactly 4 items
    and 'test' exactly 3, so all 7 inputs are accounted for.

    NOTE(review): a method with this exact name appears next to this one
    in the surrounding source; inside a single class the later definition
    replaces the earlier one, so only one of them is collected by the
    test runner -- confirm whether both are meant to exist.
    """
    class SrcExtractor(Extractor):
        def __iter__(self):
            # Fresh DatasetItem objects are built on every iteration.
            return iter([
                DatasetItem(id=item_id, subset=subset)
                for item_id, subset in enumerate(
                    ["a", "a", "b", "b", "b", "", ""], start=1)
            ])

    split_ratios = [
        ('train', 4.0 / 7.0),
        ('test', 3.0 / 7.0),
    ]
    actual = transforms.RandomSplit(SrcExtractor(), splits=split_ratios)

    expected_sizes = {'train': 4, 'test': 3}
    for subset_name, expected_size in expected_sizes.items():
        self.assertEqual(expected_size,
            len(actual.get_subset(subset_name)))