def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy):
  """Two replicas of one serialized pipeline, sharded 2 ways, share no data.

  The pipeline is serialized to a GraphDef and rebuilt twice through
  `_RemoteDataset`. Each copy gets `sharding_policy` and is auto-sharded with
  a different index; the two shards' element sets must not intersect.
  """
  source = dataset_ops.Dataset.list_files(self._filenames, shuffle=shuffle)
  source = source.flat_map(core_readers.TFRecordDataset)
  serialized = source._as_serialized_graph(
      strip_device_assignment=True,
      external_state_policy=options_lib.ExternalStatePolicy.WARN)

  shard_options = options_lib.Options()
  shard_options.experimental_distribute.auto_shard_policy = sharding_policy

  replicas = [
      distribute._RemoteDataset(serialized, "/device:CPU:0",
                                source.element_spec) for _ in range(2)
  ]
  replicas = [replica.with_options(shard_options) for replica in replicas]
  shards = [
      distribute._AutoShardDataset(replica, 2, shard_index)
      for shard_index, replica in enumerate(replicas)
  ]
  first, second = [set(self.getAllDatasetElements(s)) for s in shards]
  self.assertEmpty(first & second)
def testWorkersGreaterThanNumFiles(self):
  """The last of 500 workers receives an empty shard."""
  reader = interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)
  ds = dataset_ops.Dataset.list_files(self.test_filenames).apply(reader)
  ds = ds.batch(5)
  # More workers than files: worker 499's shard is expected to be empty.
  ds = distribute._AutoShardDataset(ds, 500, 499)
  self.assertDatasetProduces(ds, [])
def testDatasetOfReaderDatasetsPipeline(self, batch_size):
  """Auto-shards a pipeline built from a dataset of reader datasets.

  This tests a scenario where `list_files` may return multiple files because
  the glob contains wildcards. The filenames are split into groups of
  `batch_size`. Each group becomes its own
  `list_files(...).map(TFRecordDataset)` dataset, and those datasets are then
  flattened back into a single stream before auto-sharding.

  Args:
    batch_size: Number of filenames grouped into each inner `list_files`
      dataset.
  """

  def batch(items, n):
    """Yields consecutive slices of `items`, each of length at most `n`."""
    length = len(items)
    for start in range(0, length, n):
      yield items[start:min(start + n, length)]

  datasets = []
  for files in batch(self._filenames, batch_size):
    datasets.append(
        dataset_ops.Dataset.list_files(files, shuffle=False).map(
            core_readers.TFRecordDataset))
  dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
  dataset = dataset.flat_map(lambda x: x)

  # Simulate additional ops in between flat_map and interleave. These should
  # be no-ops: if ShardDataset is placed right after flat_map, only two
  # datasets remain at this point.
  dataset = dataset.prefetch(1)
  dataset = dataset.prefetch(1)
  dataset = dataset.interleave(
      lambda x: x, cycle_length=1, num_parallel_calls=1)

  dataset = distribute._AutoShardDataset(dataset, 5, 0)
  expected = [
      b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
      for f in (0, 5)
      for r in range(0, 10)
  ]
  self.assertDatasetProduces(dataset, expected)
def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
  """Shard the input pipeline by sharding the underlying list of files.

  Args:
    dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
      dataset transformations.
    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
      shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
      Same usage as in `tf.data.Dataset.shard`.
    num_replicas_in_sync: An integer representing the total number of replicas
      across all workers. This is used in the rewrite when sharding by data.
      Treated as 1 when `None`.

  Returns:
    `dataset` itself, untouched, when its
    `options().experimental_distribute.auto_shard_policy` is
    `AutoShardPolicy.OFF`. Otherwise, `dataset` wrapped in an auto-shard
    dataset: `_AutoShardDatasetV1` for a `DatasetV1` input, `_AutoShardDataset`
    for anything else. Per the original contract, the input dataset will be
    returned if no good way to shard it can be determined automatically.
  """
  policy = dataset.options().experimental_distribute.auto_shard_policy
  if policy == AutoShardPolicy.OFF:
    # Auto-sharding explicitly disabled by the user: leave the pipeline alone.
    return dataset

  if num_replicas_in_sync is None:
    num_replicas_in_sync = 1

  # Keep V1 inputs as V1 so graph-mode callers get the matching dataset API.
  if isinstance(dataset, dataset_ops.DatasetV1):
    return distribute._AutoShardDatasetV1(dataset, num_shards, index,
                                          num_replicas_in_sync)
  return distribute._AutoShardDataset(dataset, num_shards, index,
                                      num_replicas_in_sync)
def testWorkersGreaterThanNumFiles(self):
  """A worker index beyond the number of files yields nothing."""
  files = dataset_ops.Dataset.list_files(self.test_filenames)
  records = files.apply(
      interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
  batched = records.batch(5)
  last_worker = distribute._AutoShardDataset(batched, 500, 499)
  self.assertDatasetProduces(last_worker, [])
def testInvalidWorkerIndex(self):
  """An index equal to the worker count raises InvalidArgumentError."""
  pipeline = dataset_ops.Dataset.list_files(self._filenames)
  pipeline = pipeline.flat_map(core_readers.TFRecordDataset).batch(5)
  with self.assertRaises(errors.InvalidArgumentError):
    # Only indices 0 and 1 are in range for 2 workers.
    sharded = distribute._AutoShardDataset(pipeline, 2, 2)
    self.evaluate(self.getNext(sharded)())
def build_dataset():
  # `self` comes from the enclosing test method's scope (closure).
  filenames = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
  records = filenames.apply(
      interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
  return distribute._AutoShardDataset(records, 5, 3)
def testInvalidWorkerIndex(self):
  """Requesting shard 2 of 2 raises InvalidArgumentError."""
  batched = (
      dataset_ops.Dataset.list_files(self.test_filenames)
      .flat_map(core_readers.TFRecordDataset)
      .batch(5))
  with self.assertRaises(errors.InvalidArgumentError):
    out_of_range = distribute._AutoShardDataset(batched, 2, 2)
    self.evaluate(self.getNext(out_of_range)())
def testUnsupportedOpInPipeline(self):
  """A `unique` op in the pipeline makes auto-sharding raise NotFoundError."""
  pipeline = dataset_ops.Dataset.list_files(self.test_filenames)
  pipeline = pipeline.flat_map(core_readers.TFRecordDataset)
  pipeline = pipeline.batch(5).apply(unique.unique())
  with self.assertRaises(errors.NotFoundError):
    sharded = distribute._AutoShardDataset(pipeline, 2, 0)
    self.evaluate(self.getNext(sharded)())
def testUnsupportedOpInPipeline(self):
  """Auto-sharding past `unique` is expected to fail with NotFoundError."""
  source = dataset_ops.Dataset.list_files(self.test_filenames)
  deduped = (
      source.flat_map(core_readers.TFRecordDataset)
      .batch(5)
      .apply(unique.unique()))
  with self.assertRaises(errors.NotFoundError):
    sharded = distribute._AutoShardDataset(deduped, 2, 0)
    self.evaluate(self.getNext(sharded)())
def testDirectFilenameTextLineReaderPipeline(self):
  """Shards a TextLineDataset that is constructed directly from filenames."""
  lines = core_readers.TextLineDataset(self.test_filenames)
  lines = distribute._AutoShardDataset(lines, 5, 0)
  # Shard 0 of 5 keeps files 0 and 5, read one after the other.
  expected = []
  for file_idx in (0, 5):
    for line_idx in range(10):
      expected.append(b"%d: %d" % (file_idx, line_idx))
  self.assertDatasetProduces(lines, expected)
def testShardByDataBeforePrefetch(self, sharding_policy):
  """When sharding by data, the Shard op lands ahead of Prefetch."""
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = sharding_policy
  dataset = (
      dataset_ops.Dataset.range(4)
      .apply(testing.assert_next(["Shard", "Prefetch"]))
      .prefetch(1)
      .with_options(options))
  dataset = distribute._AutoShardDataset(dataset, 2, 0)
  self.assertDatasetProduces(dataset, [0, 2])
def testDirectFilenameTFRecordReaderPipeline(self):
  """Shards a TFRecordDataset constructed directly from filenames."""
  records = distribute._AutoShardDataset(
      core_readers.TFRecordDataset(self._filenames), 5, 0)
  # Files are read sequentially; shard 0 of 5 keeps files 0 and 5.
  expected = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for fidx in (0, 5)
      for rec in range(10)
  ]
  self.assertDatasetProduces(records, expected)
def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                         with_prefetch):
  """Checks per-worker output when rebatching is combined with data sharding.

  Simulates a distributed environment with 3 workers, each with 1 replica.
  The auto-shard rewrite is expected to rewrite RebatchDatasetV2 to
  RebatchDataset(V1) for correctness reasons, which modifies the output of
  the dataset.
  """
  dataset = dataset_ops.Dataset.range(8).batch(4)
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = sharding_policy
  dataset = dataset.with_options(options)

  # (worker index, rebatch sizes, expected output) for each worker, in order.
  workers = (
      (0, [2, 1, 1], [[0, 1], [4, 5]]),
      (1, [1, 1, 2], [[2, 3], [6, 7]]),
      (2, [1, 2, 1], [[], []]),
  )
  for worker_index, sizes, expected in workers:
    worker_dataset = distribute._RebatchDataset(dataset, batch_sizes=sizes)
    if with_prefetch:
      worker_dataset = worker_dataset.prefetch(1)
    worker_dataset = distribute._AutoShardDataset(
        worker_dataset, 3, worker_index, num_replicas=3)
    self.assertDatasetProduces(worker_dataset, expected)
def testEnumerateAutoShardPolicies(self, auto_shard_policy):
  """Verifies tf.data handles every auto-shard policy with no errors."""
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = auto_shard_policy
  files = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
  pipeline = (
      files.flat_map(core_readers.TFRecordDataset)
      .batch(5)
      .with_options(options))
  pipeline = distribute._AutoShardDataset(pipeline, 5, 3)
  # Only checks that iterating the sharded pipeline does not raise.
  self.getDatasetOutput(pipeline, requires_initialization=True)
def testHintShardingInvalidPattern(self):
  """HINT policy leaves an explicit `shard(1, 0)` alone; all data survives."""
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = (
      options_lib.AutoShardPolicy.HINT)
  dataset = dataset_ops.Dataset.range(100).shard(1, 0).with_options(options)
  dataset = distribute._AutoShardDataset(dataset, 10, 0)
  self.assertDatasetProduces(dataset, list(range(100)))
def testDirectFilenameTextLineReaderPipeline(self):
  """Sharding a TextLineDataset read straight from its filenames."""
  sharded = distribute._AutoShardDataset(
      core_readers.TextLineDataset(self.test_filenames), 5, 0)
  expected = [
      b"%d: %d" % (fidx, line)  # pylint:disable=g-complex-comprehension
      for fidx in (0, 5)
      for line in range(10)
  ]
  self.assertDatasetProduces(sharded, expected)
def testTFRecordReaderWithDirectFileNames(self):
  """Shards a raw `_TFRecordDataset` op fed directly with filenames.

  `_TFRecordDataset` creates the raw op instead of wrapping it in a flat_map,
  and the expected output keeps records 0 and 5 of every file.
  """
  raw = core_readers._TFRecordDataset(self._filenames)
  raw = distribute._AutoShardDataset(raw, 5, 0)
  expected = []
  for fidx in range(10):
    for rec in (0, 5):
      expected.append(b"Record %d of file %d" % (rec, fidx))
  self.assertDatasetProduces(raw, expected)
def testTFRecordReaderWithDirectFileNames(self):
  """Shards a raw `_TFRecordDataset` op fed directly with filenames."""
  # Using `_TFRecordDataset` creates a raw op rather than wrapping it in a
  # flat_map automatically.
  sharded = distribute._AutoShardDataset(
      core_readers._TFRecordDataset(self.test_filenames), 5, 0)
  expected = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for fidx in range(10)
      for rec in (0, 5)
  ]
  self.assertDatasetProduces(sharded, expected)
def testShardWithRebatch(self):
  """RebatchDatasetV2 is a passthrough op for file-based auto-sharding."""
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  pipeline = files.apply(
      testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
  pipeline = pipeline.flat_map(core_readers.TFRecordDataset).batch(5)
  pipeline = distribute._RebatchDataset(pipeline, batch_sizes=5)
  pipeline = distribute._AutoShardDataset(pipeline, 5, 3)
  get_next = self.getNext(pipeline)
  self.evaluate(get_next())
def testFlatMapReaderPipeline(self, shuffle):
  """File-shards a list_files -> flat_map(TFRecordDataset) -> batch pipeline."""
  pipeline = (
      dataset_ops.Dataset.list_files(self._filenames, shuffle=shuffle)
      .flat_map(core_readers.TFRecordDataset)
      .batch(5))
  pipeline = distribute._AutoShardDataset(pipeline, 5, 3)
  # Shard 3 of 5 keeps files 3 and 8.
  expected = []
  for fidx in (3, 8):
    expected.extend(b"Record %d of file %d" % (rec, fidx) for rec in range(10))
  self.assertDatasetProducesWithShuffle(pipeline, expected, 5, 4, shuffle)
def testFileShardingWithLegacyRebatch(self):
  """RebatchDatasetV1 is a passthrough op for file-based auto-sharding."""
  self._setUpFiles(num_files=5, num_records_per_file=10)
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  pipeline = (
      files.apply(testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
      .flat_map(core_readers.TFRecordDataset)
      .batch(5))
  pipeline = distribute._LegacyRebatchDataset(pipeline, num_replicas=5)
  pipeline = distribute._AutoShardDataset(pipeline, 5, 3)
  # Worker 3 of 5 reads only file 3; every batch of 5 is split into
  # single-element batches.
  expected = [[self._record(3, idx)] for idx in range(10)]
  self.assertDatasetProduces(pipeline, expected)
def testFileShardingWithRebatch(self):
  """RebatchDatasetV2 is a passthrough op for file-based auto-sharding."""
  self._setUpFiles(num_files=3, num_records_per_file=5)
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  pipeline = (
      files.apply(testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
      .flat_map(core_readers.TFRecordDataset)
      .batch(5))
  pipeline = distribute._RebatchDataset(pipeline, batch_sizes=[2, 1, 2])
  pipeline = distribute._AutoShardDataset(pipeline, 3, 1)
  # Worker 1 of 3 reads only file 1; its single batch of 5 is split 2/1/2.
  records = [self._record(1, idx) for idx in range(5)]
  expected = [records[0:2], records[2:3], records[3:5]]
  self.assertDatasetProduces(pipeline, expected)
def testFlatMapReaderPipeline(self, shuffle):
  """File-shards a flat_map reader pipeline, optionally shuffling files."""
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=shuffle)
  batched = files.flat_map(core_readers.TFRecordDataset).batch(5)
  sharded = distribute._AutoShardDataset(batched, 5, 3)
  expected = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for fidx in (3, 8)
      for rec in range(10)
  ]
  self.assertDatasetProducesWithShuffle(sharded, expected, 5, 4, shuffle)
def testUnknownOpInPipelineStillShardsAtTheEnd(self):
  """With an unrecognized op (`unique`), elements are sharded at the end."""
  pipeline = (
      dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
      .flat_map(core_readers.TFRecordDataset)
      .apply(unique.unique()))
  pipeline = distribute._AutoShardDataset(pipeline, 5, 0)
  # Element-level sharding keeps every 5th record: records 0 and 5 of each file.
  expected = []
  for fidx in range(10):
    for rec in (0, 5):
      expected.append(b"Record %d of file %d" % (rec, fidx))
  self.assertDatasetProduces(pipeline, expected)
def testPrivateThreadpool(self):
  """A private thread pool in the pipeline still allows file sharding."""
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  batched = files.flat_map(core_readers.TFRecordDataset).batch(5)
  pooled = dataset_ops._PrivateThreadPoolDataset(batched, 1)
  sharded = distribute._AutoShardDataset(pooled, 5, 0)
  records = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for fidx in (0, 5)
      for rec in range(10)
  ]
  self.assertDatasetProduces(sharded, list(chunk(records, 5)))
def testFileShardingWithoutReaderDatasetOp(self):
  """FILE policy on a pipeline without a file reader raises NotFoundError."""
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = (
      options_lib.AutoShardPolicy.FILE)
  numbers = dataset_ops.Dataset.range(1024).with_options(options)
  # File sharding is requested, but this pipeline does not start with file
  # reading, so sharding must fail.
  with self.assertRaises(errors.NotFoundError):
    sharded = distribute._AutoShardDataset(numbers, 10, 0)
    self.evaluate(self.getNext(sharded)())
def testAssertCardinality(self):
  """An assert_cardinality op does not prevent file-level sharding."""
  pipeline = (
      dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
      .flat_map(core_readers.TFRecordDataset)
      .batch(5)
      # NOTE(review): 42 does not appear to be the pipeline's real cardinality;
      # the test passing suggests the rewrite drops or bypasses this check.
      .apply(cardinality.assert_cardinality(42)))
  pipeline = distribute._AutoShardDataset(pipeline, 5, 0)
  records = []
  for fidx in (0, 5):
    for rec in range(10):
      records.append(b"Record %d of file %d" % (rec, fidx))
  self.assertDatasetProduces(pipeline, list(chunk(records, 5)))
def testUnknownOpInPipelineStillShardsAtTheEnd(self):
  """An unrecognized op forces sharding of the final elements instead."""
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  deduped = files.flat_map(core_readers.TFRecordDataset).apply(unique.unique())
  sharded = distribute._AutoShardDataset(deduped, 5, 0)
  expected = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for fidx in range(10)
      for rec in (0, 5)
  ]
  self.assertDatasetProduces(sharded, expected)
def testShardInputToInterleave(self):
  """Sharding applies to the filename input of an interleave."""
  paths = [
      self._writeFile(name, values)
      for name, values in (("f0", [1, 2, 3]), ("f1", [4, 5, 6]),
                           ("f2", [7, 8, 9]))
  ]
  dataset = dataset_ops.Dataset.from_tensor_slices(paths)
  dataset = dataset.interleave(core_readers.TFRecordDataset, cycle_length=3)
  dataset = distribute._AutoShardDataset(dataset, 2, 0)
  # Shard 0 of 2 keeps files 0 and 2, whose records are then interleaved.
  expected = [str(value).encode() for value in (1, 7, 2, 8, 3, 9)]
  self.assertEqual(self.getDatasetOutput(dataset), expected)
def testShardInputToInterleaveWithIdentityFunction(self):
  """Like testShardInputToInterleave, but readers are built by a prior map."""
  self.skipTest("Currently fails due to b/238645949")
  paths = [
      self._writeFile(name, values)
      for name, values in (("f0", [1, 2, 3]), ("f1", [4, 5, 6]),
                           ("f2", [7, 8, 9]))
  ]
  dataset = (
      dataset_ops.Dataset.from_tensor_slices(paths)
      .map(core_readers.TFRecordDataset)
      .interleave(lambda x: x, cycle_length=3))
  dataset = distribute._AutoShardDataset(dataset, 2, 0)
  # Sharding by file would interleave files 0 and 2.
  expected = [str(value).encode() for value in (1, 7, 2, 8, 3, 9)]
  self.assertEqual(self.getDatasetOutput(dataset), expected)
def testTFRecordReaderWithDirectFileNamesAndShapes(self):
  """Shards a batched raw `_TFRecordDataset` op built from filenames."""
  # `_TFRecordDataset` creates the raw op (no automatic flat_map wrapper);
  # the BatchDataset adds `output_types` and `output_shapes` attributes.
  batched = core_readers._TFRecordDataset(self._filenames).batch(5)
  sharded = distribute._AutoShardDataset(batched, 2, 0)
  records = []
  for fidx in range(10):
    for rec in range(5):
      records.append(b"Record %d of file %d" % (rec, fidx))
  self.assertDatasetProduces(sharded, list(chunk(records, 5)))
def testSampleResNetPipeline(self, shuffle):
  """File-shards a ResNet-style list_files -> parallel_interleave pipeline."""
  files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=shuffle)
  pipeline = files.apply(
      interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
  pipeline = distribute._AutoShardDataset(pipeline.batch(5), 5, 3)
  # Records of files 3 and 8 are interleaved one-by-one.
  expected = []
  for rec in range(10):
    for fidx in (3, 8):
      expected.append(b"Record %d of file %d" % (rec, fidx))
  self.assertDatasetProducesWithShuffle(pipeline, expected, 5, 4, shuffle)
def testSampleResNetPipeline(self, shuffle):
  """ResNet-style input pipeline sharded 5 ways, checking shard 3."""
  interleave = interleave_ops.parallel_interleave(core_readers.TFRecordDataset,
                                                  10)
  sharded = distribute._AutoShardDataset(
      dataset_ops.Dataset.list_files(self._filenames, shuffle=shuffle)
      .apply(interleave)
      .batch(5), 5, 3)
  expected = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for rec in range(10)
      for fidx in (3, 8)
  ]
  self.assertDatasetProducesWithShuffle(sharded, expected, 5, 4, shuffle)
def testAutoshardPolicyOff(self):
  """With the OFF policy, every record of every file is produced."""
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = (
      options_lib.AutoShardPolicy.OFF)
  reader = core_readers._TFRecordDataset(self._filenames).with_options(options)
  reader = distribute._AutoShardDataset(reader, 5, 0)
  # Autosharding is off, so nothing is dropped.
  expected = []
  for fidx in range(10):
    expected.extend(b"Record %d of file %d" % (rec, fidx) for rec in range(10))
  self.assertDatasetProduces(reader, expected)
def testTFRecordReaderWithDirectFileNamesAndShapes(self):
  """Shards a batched raw `_TFRecordDataset` op built from filenames."""
  # Using `_TFRecordDataset` creates a raw op rather than wrapping it in a
  # flat_map; BatchDataset contains `output_types` and `output_shapes`.
  sharded = distribute._AutoShardDataset(
      core_readers._TFRecordDataset(self.test_filenames).batch(5), 2, 0)
  records = [
      b"Record %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for fidx in range(10)
      for rec in range(5)
  ]
  self.assertDatasetProduces(sharded, list(chunk(records, 5)))
def testPipelineWithMap(self):
  """File-sharding works through a map applied after the reader interleave.

  The assertion below checks the exact output order. With shuffled file
  listing the relative order of files 3 and 8 (and therefore the interleaved
  record order) would vary between runs, so the file list is not shuffled.
  """
  dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
  # Drops the first two bytes ("Re") of each record.
  dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
  dataset = dataset.batch(5)
  dataset = distribute._AutoShardDataset(dataset, 5, 3)

  expected = [
      b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
      for r in range(0, 10)
      for f in (3, 8)
  ]
  self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
def testWorkersGreaterThanNumFilesWithDataSharding(self):
  """DATA policy shards individual records, so every file contributes."""
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = (
      options_lib.AutoShardPolicy.DATA)
  reader = core_readers._TFRecordDataset(self._filenames).with_options(options)
  reader = distribute._AutoShardDataset(reader, 5, 0)
  # Expect "Record (0, 5) of file (0 --> 9)": data sharding keeps every 5th
  # record across all files.
  expected = []
  for fidx in range(10):
    for rec in (0, 5):
      expected.append(b"Record %d of file %d" % (rec, fidx))
  self.assertDatasetProduces(reader, expected)
def testValidPipelineWithRangeDataset(self, shuffle):
  """File-shards a pipeline whose filenames are generated from a range."""

  def to_filename(n):
    return string_ops.string_join(
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])])

  def strip_prefix(record):
    return string_ops.substr_v2(record, 2, 1000)

  pipeline = dataset_ops.Dataset.range(self._num_files).map(to_filename)
  pipeline = pipeline.apply(
      interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
  pipeline = pipeline.map(strip_prefix).batch(5)
  pipeline = distribute._AutoShardDataset(pipeline, 5, 3)
  expected = []
  for rec in range(10):
    for fidx in (3, 8):
      expected.append(b"cord %d of file %d" % (rec, fidx))
  self.assertDatasetProducesWithShuffle(pipeline, expected, 5, 4, shuffle)
def testValidPipelineWithRangeDataset(self, shuffle):
  """Range-generated filenames still allow file-level auto-sharding."""
  temp_dir = self.get_temp_dir

  def filename_for(n):
    suffix = string_ops.string_format("/tf_record.{}.txt", [n])
    return string_ops.string_join([temp_dir(), suffix])

  sharded = distribute._AutoShardDataset(
      dataset_ops.Dataset.range(self._num_files)
      .map(filename_for)
      .apply(interleave_ops.parallel_interleave(core_readers.TFRecordDataset,
                                                10))
      .map(lambda x: string_ops.substr_v2(x, 2, 1000))
      .batch(5), 5, 3)
  expected = [
      b"cord %d of file %d" % (rec, fidx)  # pylint:disable=g-complex-comprehension
      for rec in range(10)
      for fidx in (3, 8)
  ]
  self.assertDatasetProducesWithShuffle(sharded, expected, 5, 4, shuffle)
def testZipReaderPipeline(self):
  """Both branches of a zip of two identical reader pipelines get sharded."""

  def make_reader():
    files = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    return files.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))

  zipped = dataset_ops.Dataset.zip((make_reader(), make_reader()))
  zipped = distribute._AutoShardDataset(zipped, 5, 3)
  expected = []
  for rec in range(10):
    for fidx in (3, 8):
      record = b"Record %d of file %d" % (rec, fidx)
      expected.append((record, record))
  self.assertDatasetProduces(zipped, expected)
def auto_shard_dataset(dataset, num_shards, index):
  """Shard the input pipeline by sharding the underlying list of files.

  Args:
    dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
      dataset transformations.
    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
      shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
      Same usage as in `tf.data.Dataset.shard`.

  Returns:
    `dataset` wrapped in `_AutoShardDatasetV1` when it is a `DatasetV1`,
    otherwise in `_AutoShardDataset`. Per the original contract, the input
    dataset will be returned if we cannot automatically determine a good way
    to shard it.
  """
  # Match the wrapper's API version to the input's.
  shard_cls = (
      distribute._AutoShardDatasetV1
      if isinstance(dataset, dataset_ops.DatasetV1)
      else distribute._AutoShardDataset)
  return shard_cls(dataset, num_shards, index)
def testStandardReaderPipeline(self, num_epochs, index, batch_size,
                               parallel_reads):
  """Auto-shards a `make_tf_record_dataset` pipeline across 2 workers.

  Args:
    num_epochs: Number of passes over the files.
    index: Index of this worker among the 2 shards.
    batch_size: Batch size passed to `make_tf_record_dataset`.
    parallel_reads: `num_parallel_reads`, also used as the expected
      interleave cycle length.
  """
  dataset = readers.make_tf_record_dataset(
      file_pattern=self.test_filenames,
      num_epochs=num_epochs,
      batch_size=batch_size,
      parser_fn=None,
      num_parallel_reads=parallel_reads,
      drop_final_batch=True,
      shuffle=False)
  dataset = distribute._AutoShardDataset(dataset, 2, index)
  outputs = self.getNext(dataset)
  # Worker `index` of 2 should see files index, index + 2, ...; the upper
  # bound is the number of files, not the number of records per file.
  self._verify_records(
      outputs,
      batch_size=batch_size,
      file_index=list(range(index, len(self.test_filenames), 2)),
      num_epochs=num_epochs,
      interleave_cycle_length=parallel_reads,
      drop_final_batch=True,
      use_parser_fn=None)
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(outputs())
def testNoReaderPipelines(self):
  """Without a reader, elements are sharded directly (every other one)."""
  sharded = distribute._AutoShardDataset(dataset_ops.Dataset.range(1024), 2, 0)
  self.assertDatasetProduces(sharded, list(range(0, 1024, 2)))
def testNoReaderPipelines(self):
  """A pipeline with no file reader raises NotFoundError when auto-sharded."""
  numbers = dataset_ops.Dataset.range(1024)
  with self.assertRaises(errors.NotFoundError):
    sharded = distribute._AutoShardDataset(numbers, 2, 0)
    self.evaluate(self.getNext(sharded)())
def testShardOutOfRangeEmptyDataset(self):
  """Sharding an empty dataset simply reaches end of sequence."""
  empty = dataset_ops.Dataset.range(0)
  with self.assertRaises(errors.OutOfRangeError):
    sharded = distribute._AutoShardDataset(empty, 10, 0)
    self.evaluate(self.getNext(sharded)())
def testShardOutOfRange(self):
  """More shards (10) than elements (5) raises InvalidArgumentError."""
  five = dataset_ops.Dataset.range(5)
  with self.assertRaises(errors.InvalidArgumentError):
    sharded = distribute._AutoShardDataset(five, 10, 0)
    self.evaluate(self.getNext(sharded)())