예제 #1
0
 def test_append_to_different_dataset_fails(self):
     """Opening existing data for a differently-named dataset raises.

     A first writer creates a closed, empty ``train`` split in a temporary
     directory. Constructing a second writer for the same directory, using a
     dataset info named ``'new_name'`` and ``overwrite=False``, is expected to
     raise ``ValueError``.
     """
     output_dir = tempfile.mkdtemp('data_dir')
     initial_writer = sequential_writer.SequentialWriter(
         ds_info=_dataset_info(output_dir), max_examples_per_shard=3)
     initial_writer.initialize_splits(['train'])
     initial_writer.close_splits(['train'])

     with self.assertRaises(ValueError):
         renamed_info = _dataset_info(output_dir, 'new_name')
         sequential_writer.SequentialWriter(
             renamed_info, max_examples_per_shard=3, overwrite=False)
예제 #2
0
 def test_fails_to_initialize_same_split_twice(self):
     """Initializing an already-initialized split (default args) raises KeyError."""
     writer = sequential_writer.SequentialWriter(
         ds_info=_dataset_info(tempfile.mkdtemp('data_dir')),
         max_examples_per_shard=1)
     writer.initialize_splits(['train'])
     # The second call, without fail_if_exists=False, must be rejected.
     with self.assertRaises(KeyError):
         writer.initialize_splits(['train'])
예제 #3
0
 def test_closing_a_closed_split_is_a_noop(self):
     """Closing the same split twice must not raise."""
     tmp_dir = tempfile.mkdtemp('data_dir')
     writer = sequential_writer.SequentialWriter(
         max_examples_per_shard=1, ds_info=_dataset_info(tmp_dir))
     writer.initialize_splits(['train'])
     for _ in range(2):
         # The second iteration is the one under test.
         writer.close_splits(['train'])
예제 #4
0
    def test_writes_multiple_splits_sequentially(self):
        """Splits written one after another are sharded and counted correctly.

        ``train`` (5 examples) is written and closed before ``eval``
        (7 examples) is initialized. With at most 3 examples per shard, the
        reloaded dataset must report 2 and 3 shards respectively, and each
        split must yield exactly the number of examples written to it.
        """
        shard_limit = 3
        # split name -> (examples written, expected shard count); dict order
        # fixes the write order: train first, then eval.
        expectations = {
            'train': (5, 2),
            'eval': (7, 3),
        }

        output_dir = tempfile.mkdtemp('data_dir')
        writer = sequential_writer.SequentialWriter(
            ds_info=_dataset_info(output_dir),
            max_examples_per_shard=shard_limit)
        for split, (num_examples, _) in expectations.items():
            writer.initialize_splits([split])
            writer.add_examples(
                {split: list(tfds.as_numpy(_generate_split(num_examples)))})
            writer.close_splits([split])

        ds_builder = tfds.builder_from_directory(output_dir)
        for split, (_, num_shards) in expectations.items():
            self.assertEqual(ds_builder.info.splits[split].num_shards,
                             num_shards)
        for split, (num_examples, _) in expectations.items():
            num_read = ds_builder.as_dataset(split=split).reduce(
                0, lambda count, _: count + 1).numpy()
            self.assertEqual(num_read, num_examples)
예제 #5
0
 def test_append_to_non_existent_works(self):
     """With overwrite=False, writing into a newly created directory succeeds."""
     writer = sequential_writer.SequentialWriter(
         ds_info=_dataset_info(tempfile.mkdtemp('data_dir')),
         max_examples_per_shard=3,
         overwrite=False,
     )
     # Neither call is expected to raise.
     writer.initialize_splits(['train'])
     writer.close_splits(['train'])
예제 #6
0
 def test_fails_to_append_to_nonexisting_split(self):
     """Adding examples to a split that was never initialized raises KeyError."""
     info = _dataset_info('/unused/dir')
     writer = sequential_writer.SequentialWriter(
         ds_info=info, max_examples_per_shard=1)
     with self.assertRaises(KeyError):
         examples = list(tfds.as_numpy(_generate_split(1)))
         writer.add_examples({'train': examples})
예제 #7
0
 def test_fails_to_append_to_closed_split(self):
     """Adding examples to a split after it has been closed raises ValueError."""
     split = 'train'
     writer = sequential_writer.SequentialWriter(
         ds_info=_dataset_info(tempfile.mkdtemp('data_dir')),
         max_examples_per_shard=1)
     writer.initialize_splits([split])
     writer.close_splits([split])
     with self.assertRaises(ValueError):
         writer.add_examples({split: list(tfds.as_numpy(_generate_split(1)))})
예제 #8
0
    def test_overwrites(self):
        """A writer created with overwrite=True replaces an existing split.

        The first pass writes 4 examples with at most 3 per shard (2 shards).
        The second pass, on the same directory with ``overwrite=True``, writes
        3 examples with at most 2 per shard (2 shards). Afterwards only the
        second pass's data may be visible: 2 shards holding 3 examples.
        """
        output_dir = tempfile.mkdtemp('data_dir')
        info = _dataset_info(output_dir)

        def write_train(num_examples, writer_kwargs, init_kwargs):
            """Writes `num_examples` generated examples to `train` with a new writer."""
            writer = sequential_writer.SequentialWriter(
                ds_info=info, **writer_kwargs)
            writer.initialize_splits(['train'], **init_kwargs)
            examples = list(tfds.as_numpy(_generate_split(num_examples)))
            writer.add_examples({'train': examples})
            writer.close_splits(['train'])

        # First pass: 4 examples, <= 3 per shard -> 2 shards.
        write_train(4, {'max_examples_per_shard': 3}, {})
        # Second pass replaces it: 3 examples, <= 2 per shard -> 2 shards.
        write_train(3, {'max_examples_per_shard': 2, 'overwrite': True},
                    {'fail_if_exists': True})

        reloaded = tfds.builder_from_directory(output_dir)
        self.assertEqual(reloaded.info.splits['train'].num_shards, 2)
        num_read = reloaded.as_dataset(split='train').reduce(
            0, lambda acc, _: acc + 1).numpy()
        self.assertEqual(num_read, 3)
예제 #9
0
    def test_generates_correct_shards(self, num_examples,
                                      max_examples_per_shard,
                                      num_expected_shards):
        """Examples added one at a time end up in the expected number of shards.

        Args:
          num_examples: how many examples to write to the ``train`` split.
          max_examples_per_shard: shard size limit passed to the writer.
          num_expected_shards: shard count the reloaded split must report.
        """
        output_dir = tempfile.mkdtemp('data_dir')
        writer = sequential_writer.SequentialWriter(
            ds_info=_dataset_info(output_dir),
            max_examples_per_shard=max_examples_per_shard)
        writer.initialize_splits(['train'])
        # Each example goes through its own add_examples call (batch of one).
        examples = tfds.as_numpy(_generate_split(num_examples))
        for single in examples:
            writer.add_examples({'train': [single]})
        writer.close_splits(['train'])

        reloaded = tfds.builder_from_directory(output_dir)
        self.assertEqual(reloaded.info.splits['train'].num_shards,
                         num_expected_shards)
        num_read = reloaded.as_dataset(split='train').reduce(
            0, lambda acc, _: acc + 1).numpy()
        self.assertEqual(num_read, num_examples)
예제 #10
0
 def test_initializes_same_split_twice(self):
     """Re-initializing a split is allowed when ``fail_if_exists=False``."""
     writer = sequential_writer.SequentialWriter(
         ds_info=_dataset_info(tempfile.mkdtemp('data_dir')),
         max_examples_per_shard=1)
     # First call uses the defaults; the second opts out of the existence check
     # and must therefore not raise.
     for init_kwargs in ({}, {'fail_if_exists': False}):
         writer.initialize_splits(['train'], **init_kwargs)
예제 #11
0
 def test_fails_to_close_a_nonexisting_split(self):
     """Closing a split that was never initialized raises KeyError."""
     writer = sequential_writer.SequentialWriter(
         max_examples_per_shard=1, ds_info=_dataset_info('/unused/dir'))
     never_initialized = ['train']
     with self.assertRaises(KeyError):
         writer.close_splits(never_initialized)