class MultiDeviceIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `MultiDeviceIterator`.

  A `MultiDeviceIterator` pulls elements from a host dataset and distributes
  them across a list of target devices; the assertions below show the
  distribution is round-robin in device-list order (device k receives
  elements k, k + num_devices, ...).
  """

  def setUp(self):
    super(MultiDeviceIteratorTest, self).setUp()
    # Three logical devices: most tests use devices 1 and 2 as the prefetch
    # targets, leaving device 0 for the host-side dataset.
    self._devices = self.configureDevicesForMultiDeviceTest(3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_inits=[0, 1, 42])))
  def testInitOnly(self, num_inits):
    """Running the initializer zero, one, or many times must not error."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])
    for _ in range(num_inits):
      self.evaluate(multi_device_iterator.initializer)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(max_buffer_size=[0, 1, 10],
                               prefetch_buffer_size=[0, 1, 10])))
  def testBasic(self, prefetch_buffer_size, max_buffer_size):
    """Round-robin delivery across two devices for various buffer sizes."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]],
        max_buffer_size=max_buffer_size,
        prefetch_buffer_size=prefetch_buffer_size)
    self.evaluate(multi_device_iterator.initializer)
    # Device 1 receives the even elements, device 2 the odd ones.
    for i in range(0, 10, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testOneOnSameDevice(self):
    """One of the target devices may be the same device as the dataset."""
    dataset = dataset_ops.Dataset.range(12)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[0], self._devices[1], self._devices[2]])
    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 12, 3):
      elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_0))
      self.assertEqual(i + 1, self.evaluate(elem_on_1))
      self.assertEqual(i + 2, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_0)
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatDevices(self):
    """The same device may appear multiple times in the device list."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[1]])
    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elements = multi_device_iterator.get_next()
      elem_on_1, elem_on_2 = elements
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elements = multi_device_iterator.get_next()
      elem_on_1, elem_on_2 = elements
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testNotFullyDivisible(self):
    """With 9 elements over 2 devices, the last round only feeds device 1."""
    dataset = dataset_ops.Dataset.range(9)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])
    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 8, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    # The leftover ninth element can still be fetched for device 1 alone.
    elem_on_1 = multi_device_iterator.get_next(self._devices[1])
    self.assertEqual(8, self.evaluate(elem_on_1))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testGetNextAsOptional(self):
    """`get_next_as_optional` yields empty optionals at end of sequence."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])
    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
      has_elem_1, get_elem_1 = self.evaluate(
          [elem_on_1.has_value(), elem_on_1.get_value()])
      has_elem_2, get_elem_2 = self.evaluate(
          [elem_on_2.has_value(), elem_on_2.get_value()])
      self.assertTrue(has_elem_1)
      self.assertEqual(i, get_elem_1)
      self.assertTrue(has_elem_2)
      self.assertEqual(i + 1, get_elem_2)
    # After exhaustion, has_value() is False and get_value() raises.
    elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
    has_elem_1 = elem_on_1.has_value()
    has_elem_2 = elem_on_2.has_value()
    self.assertFalse(self.evaluate(has_elem_1))
    self.assertFalse(self.evaluate(has_elem_2))
    with self.assertRaises(errors.InvalidArgumentError):
      elem_1 = elem_on_1.get_value()
      self.evaluate(elem_1)
    with self.assertRaises(errors.InvalidArgumentError):
      elem_2 = elem_on_2.get_value()
      self.evaluate(elem_2)

  @combinations.generate(test_base.default_test_combinations())
  def testUneven(self):
    """One device may be drained completely before the other starts."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]], max_buffer_size=4)
    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1 = multi_device_iterator.get_next(self._devices[1])
      self.assertEqual(i, self.evaluate(elem_on_1))
    for i in range(0, 10, 2):
      elem_on_2 = multi_device_iterator.get_next(self._devices[2])
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.graph_only_combinations())
  def testMultipleInitializationsGraph(self):
    """In graph mode, re-running the initializer restarts the iterator."""
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]],
        prefetch_buffer_size=4)
    elem_on_1, elem_on_2 = multi_device_iterator.get_next()
    for _ in range(5):
      self.evaluate(multi_device_iterator.initializer)
      # Each re-initialization starts again from the first two elements.
      self.assertEqual([(0, 0), (1, 1)],
                       self.evaluate([elem_on_1, elem_on_2]))

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleInitializationsEager(self):
    """In eager mode, fresh iterators can be created and used repeatedly."""
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    for _ in range(5):
      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
          dataset, [self._devices[1], self._devices[2]],
          prefetch_buffer_size=4)
      self.evaluate(multi_device_iterator.initializer)
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual([(0, 0), (1, 1)],
                       self.evaluate([elem_on_1, elem_on_2]))

  @combinations.generate(test_base.default_test_combinations())
  def testOptimization(self):
    """Graph optimizations still apply to the per-device pipelines."""
    dataset = dataset_ops.Dataset.range(10)
    # assert_next verifies that after noop elimination the cache op is the
    # next transformation (the skip(0) has been removed).
    dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
    dataset = dataset.skip(0)  # this should be optimized away
    dataset = dataset.cache()

    options = options_lib.Options()
    options.experimental_optimization.noop_elimination = True
    dataset = dataset.with_options(options)

    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])
    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)
class AutoShardDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):
  """Tests for `_AutoShardDataset` rewrites.

  The fixture creates 10 TFRecord files with 10 records each; record text is
  b"Record <r> of file <f>".  `_AutoShardDataset(ds, n, i)` shards the
  pipeline across n workers and keeps worker i's shard, either by file or by
  individual elements depending on the pipeline shape / sharding policy.
  """

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    # _createFiles comes from TFRecordDatasetTestBase.
    self.test_filenames = self._createFiles()

  def getAllDatasetElements(self, dataset):
    """Drains `dataset` eagerly and returns all elements as a list."""
    actual = []
    next_fn = self.getNext(dataset)
    while True:
      try:
        actual.append(self.evaluate(next_fn()))
      except errors.OutOfRangeError:
        break
    return actual

  def assertDatasetProducesWithShuffle(self, dataset, expected, batch,
                                       num_examples, shuffle):
    """Asserts dataset output; order-insensitive when `shuffle` is True.

    When shuffled, file order is nondeterministic, so the elements of
    `num_examples` batches are flattened and compared as a multiset.
    Otherwise the output must equal `expected` chunked into `batch`-sized
    batches. `chunk` is a module-level helper defined elsewhere in this file.
    """
    if shuffle:
      actual = []
      next_fn = self.getNext(dataset)
      for _ in range(num_examples):
        elem = self.evaluate(next_fn())
        if isinstance(elem, tuple):
          actual.extend(elem)
        else:
          actual.extend(elem.tolist())
      self.assertCountEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_fn())
    else:
      self.assertDatasetProduces(dataset, list(chunk(expected, batch)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle=[True, False])))
  def testFlatMapReaderPipeline(self, shuffle):
    """File-level sharding: worker 3 of 5 gets files 3 and 8."""
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (3, 8)
        for r in range(0, 10)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(test_base.default_test_combinations())
  def testZipReaderPipeline(self):
    """Sharding applies consistently to both branches of a zip."""
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f))  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle=[True, False])))
  def testConcatenateReaderPipeline(self, shuffle):
    """Both inputs of a concatenate are sharded; output is duplicated."""
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset1 = dataset1.batch(5)
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset2.batch(5)
    dataset = dataset1.concatenate(dataset2)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    expected += expected
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle=[True, False])))
  def testPipelineWithMap(self, shuffle):
    """A map between the reader and the shard point is preserved."""
    dataset = dataset_ops.Dataset.list_files(self.test_filenames,
                                             shuffle=False)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    # substr strips the leading "Re", hence the b"cord ..." expectations.
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(test_base.default_test_combinations())
  def testDirectFilenameTFRecordReaderPipeline(self):
    """Sharding a reader constructed directly from a filename list."""
    dataset = core_readers.TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle=[True, False])))
  def testValidPipelineWithRangeDataset(self, shuffle):
    """Filenames built via range + string ops still shard correctly."""
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [
            self.get_temp_dir(),
            string_ops.string_format("/tf_record.{}.txt", [n])
        ]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5),
                                       (10, 1, 1, 10)])))
  def testStandardReaderPipeline(self, params):
    """Sharding the canonical `make_tf_record_dataset` pipeline.

    params = (num_epochs, worker index, batch_size, parallel reads).
    """
    num_epochs, index, batch_size, parallel_reads = params
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        parser_fn=None,
        num_parallel_reads=parallel_reads,
        drop_final_batch=True,
        shuffle=False)
    dataset = distribute._AutoShardDataset(dataset, 2, index)
    outputs = self.getNext(dataset)
    # _verify_records comes from TFRecordDatasetTestBase.
    self._verify_records(
        outputs,
        batch_size=batch_size,
        file_index=[i for i in range(index, self._num_records, 2)],
        num_epochs=num_epochs,
        interleave_cycle_length=parallel_reads,
        drop_final_batch=True,
        use_parser_fn=None)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle=[True, False])))
  def testSampleResNetPipeline(self, shuffle):
    """A ResNet-style input pipeline shards at the file level."""
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(sharding_policy=[
                  distribute_options.AutoShardPolicy.DATA,
                  distribute_options.AutoShardPolicy.FILE
              ]), combinations.combine(shuffle=[True, False]))))
  def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy):
    """Two replicas sharded as workers 0 and 1 of 2 see disjoint elements."""
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)

    # Serialize the pipeline and rehydrate it twice to simulate replication
    # onto two workers.
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=dataset.options(
        ).experimental_external_state_policy)

    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy

    ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)
    ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)

    ds1 = ds1.with_options(options)
    ds2 = ds2.with_options(options)

    ds1 = distribute._AutoShardDataset(ds1, 2, 0)
    ds2 = distribute._AutoShardDataset(ds2, 2, 1)

    elems1 = set(self.getAllDatasetElements(ds1))
    elems2 = set(self.getAllDatasetElements(ds2))

    self.assertEmpty(elems1.intersection(elems2))

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFilesWithDataSharding(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.DATA)

    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return "Record (0,5) of file (0 --> 9)" since we are sharding by
    # individual elements, we should be able to get some data from all files.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testAutoshardPolicyOff(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.OFF)

    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return every record in every file since autosharding is turned
    # off.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithoutReaderDatasetOp(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.FILE)

    dataset = dataset_ops.Dataset.range(1024)
    dataset = dataset.with_options(options)

    # We are specifying that we want a file sharding policy, and this pipeline
    # doesn't start with file reading, so we should error out.
    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFiles(self):
    """A worker index past the last file simply produces no data."""
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 500, 499)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNames(self):
    # Using `_TFRecordDataset` creates a raw op rather than wrapping it
    # around a flat_map automatically.
    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNamesAndShapes(self):
    # Using `_TFRecordDataset` creates a raw op rather than wrapping it
    # around a flat_map automatically.
    dataset = core_readers._TFRecordDataset(self.test_filenames)

    # BatchDataset contains `output_types` and `output_shapes`
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 5)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRange(self):
    """More workers than elements is an error for a non-reader pipeline."""
    dataset = dataset_ops.Dataset.range(5)
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRangeEmptyDataset(self):
    dataset = dataset_ops.Dataset.range(0)
    with self.assertRaises(errors.OutOfRangeError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testShardWithRebatch(self):
    # Tests that Rebatch is a passthrough op.
    # NOTE(review): this method uses `optimization.assert_next` while the
    # MultiDeviceIterator tests use `testing.assert_next` — presumably both
    # modules are imported at the top of this file; confirm before unifying.
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        optimization.assert_next(["Shard", "FlatMap", "BatchV2", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._RebatchDataset(dataset, num_replicas=1)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    nxt = self.getNext(dataset)
    self.evaluate(nxt())

  @combinations.generate(test_base.default_test_combinations())
  def testNoReaderPipelines(self):
    """Without a reader, sharding falls back to element-wise (DATA)."""
    dataset = dataset_ops.Dataset.range(1024)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset,
                               [i for i in range(1024) if i % 2 == 0])

  @combinations.generate(test_base.default_test_combinations())
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
    """An op the rewriter can't walk through forces sharding at the end."""
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testInvalidWorkerIndex(self):
    """Worker index must be strictly less than the number of workers."""
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)

    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 2, 2)
      self.evaluate(self.getNext(dataset)())
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.Dataset.padded_batch`."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=[32, 34],
              padded_shapes=[[None], [25]],
              drop_remainder=[True, False])))
  def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
    """Pads variable-length rows to a fixed or dynamic padded shape.

    Each source element is a vector of length x filled with the value x, so
    the padded region (value 0) is distinguishable from real data.
    """
    seq_lens = np.random.randint(20, size=(count,)).astype(np.int32)
    batch_size = 4
    dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
        lambda x: array_ops.fill([x], x)).padded_batch(
            batch_size=batch_size,
            drop_remainder=drop_remainder,
            padded_shapes=padded_shapes)

    num_full_batches = len(seq_lens) // batch_size
    get_next = self.getNext(dataset)
    for i in range(num_full_batches):
      result = self.evaluate(get_next())
      # Dynamic shapes (None / -1) pad to the longest row in the batch.
      padded_len = padded_shapes[0]
      if padded_len is None or padded_len == -1:
        padded_len = np.max(result) if result.size > 0 else 0
      self.assertEqual((batch_size, padded_len), result.shape)
      for j in range(batch_size):
        seq_len = seq_lens[(i * batch_size) + j]
        self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
        self.assertAllEqual(result[j, seq_len:],
                            [0] * (padded_len - seq_len))

    if not drop_remainder and len(seq_lens) % batch_size > 0:
      result = self.evaluate(get_next())
      padded_len = padded_shapes[0]
      if padded_len is None or padded_len == -1:
        padded_len = np.max(result) if result.size > 0 else 0
      self.assertEqual((len(seq_lens) % batch_size, padded_len),
                       result.shape)
      for j in range(len(seq_lens) % batch_size):
        seq_len = seq_lens[num_full_batches * batch_size + j]
        self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
        self.assertAllEqual(result[j, seq_len:],
                            [0] * (padded_len - seq_len))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShortPadding(self):
    """A row longer than the padded shape raises DataLossError."""
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(
            [6, 5, 5, 5, 5]).map(lambda x: array_ops.fill([x], x))
        .padded_batch(batch_size=4, padded_shapes=[5]))
    self.assertDatasetProduces(
        dataset, expected_error=(errors.DataLossError, ''))

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchEmptyTensors(self):
    """Batching all-empty rows produces a batch of empty vectors."""
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(
            [0, 0, 0, 0]).map(lambda x: array_ops.fill([x], x))
        .padded_batch(batch_size=4, padded_shapes=[-1]))
    self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultPaddedShapes(self):
    """Omitting `padded_shapes` pads to the per-batch maximum."""

    def fill(x):
      return array_ops.fill([x], x)

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3, 4]).map(fill).padded_batch(batch_size=2))
    self.assertDatasetProduces(
        dataset,
        expected_output=[[[1, 0], [2, 2]], [[3, 3, 3, 0], [4, 4, 4, 4]]])

  @combinations.generate(test_base.default_test_combinations())
  def testNestedDefaultPaddedShapes(self):
    """Default padded shapes also work for tuple-structured elements."""

    def fill_tuple(x):
      return (x, array_ops.fill([x], x))

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3, 4]).map(fill_tuple).padded_batch(batch_size=2))
    self.assertDatasetProduces(
        dataset,
        expected_output=[([1, 2], [[1, 0], [2, 2]]),
                         ([3, 4], [[3, 3, 3, 0], [4, 4, 4, 4]])])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(padding_values=[(-1, '<end>', {
              'structure': ''
          }), (-1, '<end>', None)])))
  def testPaddedBatchDatasetNonDefaultPadding(self, padding_values):
    """Custom padding values per component (None means type default)."""

    def fill_tuple(x):
      filled = array_ops.fill([x], x)
      return (filled, string_ops.as_string(filled), {
          'structure': string_ops.as_string(filled)
      })

    random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(
            fill_tuple).padded_batch(
                4,
                padded_shapes=([-1], [-1], {'structure': [-1]}),
                padding_values=padding_values))

    get_next = self.getNext(dataset)
    for i in range(8):
      result = self.evaluate(get_next())
      padded_len = np.max(result[0])
      self.assertEqual((4, padded_len), result[0].shape)
      self.assertEqual((4, padded_len), result[1].shape)
      self.assertEqual((4, padded_len), result[2]['structure'].shape)
      for j in range(4):
        seq_len = random_seq_lens[(i * 4) + j]
        self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
        self.assertAllEqual(result[0][j, seq_len:],
                            [-1] * (padded_len - seq_len))
        self.assertAllEqual(result[1][j, :seq_len],
                            [compat.as_bytes(str(seq_len))] * seq_len)
        self.assertAllEqual(result[1][j, seq_len:],
                            [b'<end>'] * (padded_len - seq_len))
        self.assertAllEqual(result[2]['structure'][j, :seq_len],
                            [compat.as_bytes(str(seq_len))] * seq_len)
        self.assertAllEqual(result[2]['structure'][j, seq_len:],
                            [b''] * (padded_len - seq_len))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchDatasetUnicode(self):
    # See GitHub issue 16149
    def generator():
      data = [[u'Простой', u'тест', u'юникода'],
              [u'никогда', u'не', u'бывает', u'простым']]

      for seq in data:
        yield seq, [0, 1, 2, 3]

    dataset = dataset_ops.Dataset.from_generator(
        generator, (dtypes.string, dtypes.int32),
        (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
    padded_dataset = dataset.padded_batch(
        2, padded_shapes=([None], [None]), padding_values=('', 0))
    next_element = self.getNext(padded_dataset)
    self.evaluate(next_element())

  @combinations.generate(test_base.graph_only_combinations())
  def testPaddedBatchDatasetShapeSpecifications(self):
    """All supported spellings of `padded_shapes` yield the same shapes."""
    int_placeholder = array_ops.placeholder(dtypes.int32)
    float_placeholder = array_ops.placeholder(dtypes.float32)
    string_placeholder = array_ops.placeholder(dtypes.string)
    input_dataset = dataset_ops.Dataset.from_tensors(
        (int_placeholder, float_placeholder, string_placeholder))

    # Test different ways of specifying the `padded_shapes` argument.
    dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
        32,
        padded_shapes=(tensor_shape.TensorShape([None]),
                       tensor_shape.TensorShape([None, None]),
                       tensor_shape.TensorShape([37])))
    dynamic_padding_from_lists = input_dataset.padded_batch(
        32, padded_shapes=([None], [None, None], [37]))
    dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
        32, padded_shapes=([-1], [-1, -1], [37]))
    dynamic_padding_from_tensors = input_dataset.padded_batch(
        32,
        padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
                       constant_op.constant([-1, -1], dtype=dtypes.int64),
                       constant_op.constant([37], dtype=dtypes.int64)))

    for dataset in [
        dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
        dynamic_padding_from_lists_with_minus_one,
        dynamic_padding_from_tensors
    ]:
      dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
      self.assertEqual([None, None], dataset_output_shapes[0].as_list())
      self.assertEqual([None, None, None],
                       dataset_output_shapes[1].as_list())
      self.assertEqual([None, 37], dataset_output_shapes[2].as_list())

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchSparseError(self):
    """Sparse tensors are rejected by padded_batch."""
    st = sparse_tensor.SparseTensorValue(
        indices=[[0, 0]], values=([42]), dense_shape=[1, 1])

    with self.assertRaises(TypeError):
      _ = dataset_ops.Dataset.from_tensors(st).repeat(10).padded_batch(10)

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchRaggedError(self):
    """Ragged tensors are rejected by padded_batch."""
    rt = ragged_tensor_value.RaggedTensorValue(
        np.array([0, 42]), np.array([0, 2], dtype=np.int64))

    with self.assertRaises(TypeError):
      _ = dataset_ops.Dataset.from_tensors(rt).repeat(10).padded_batch(10)

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeErrorWrongRank(self):
    with self.assertRaisesRegex(
        ValueError, r'The padded shape \(1,\) is not compatible with the '
        r'corresponding input component shape \(\).'):
      _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeErrorTooSmall(self):
    with self.assertRaisesRegex(
        ValueError, r'The padded shape \(1,\) is not compatible with the '
        r'corresponding input component shape \(3,\).'):
      _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
          5, padded_shapes=[1])

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeErrorShapeNotRank1(self):
    with self.assertRaisesRegex(
        ValueError, r'Padded shape .* must be a 1-D tensor '
        r'of tf.int64 values, but its shape was \(2, 2\).'):
      _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
          5, padded_shapes=[[1, 1], [1, 1]])

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeErrorShapeNotInt(self):
    with self.assertRaisesRegex(
        TypeError, r'Padded shape .* must be a 1-D tensor '
        r'of tf.int64 values, but its element type was float32.'):
      _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
          5, padded_shapes=constant_op.constant([1.5, 2., 3.]))

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeErrorWrongRankFromTensor(self):
    with self.assertRaisesRegex(
        ValueError, r'The padded shape \(1,\) is not compatible with the '
        r'corresponding input component shape \(\).'):
      shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
      _ = dataset_ops.Dataset.range(10).padded_batch(
          5, padded_shapes=shape_as_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchShapeErrorDefaultShapeWithUnknownRank(self):
    with self.assertRaisesRegex(ValueError,
                                r'`padded_shapes`.*unknown rank'):
      ds = dataset_ops.Dataset.from_generator(
          lambda: iter([1, 2, 3]), output_types=dtypes.int32)
      ds.padded_batch(2)

  @combinations.generate(test_base.graph_only_combinations())
  def testPaddedBatchShapeErrorPlaceholder(self):
    with self.assertRaisesRegex(
        ValueError,
        r'The padded shape \((\?|None), (\?|None)\) is not compatible with '
        r'the '
        r'corresponding input component shape \(\).'):
      shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
      _ = dataset_ops.Dataset.range(10).padded_batch(
          5, padded_shapes=shape_as_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testPaddedBatchBfloat16(self):
    """padded_batch supports bfloat16 elements."""
    ds = dataset_ops.Dataset.range(5)
    ds = ds.map(lambda x: math_ops.cast(x, dtypes.bfloat16))
    ds = ds.padded_batch(10)
    self.assertDatasetProduces(
        ds, expected_output=[[0.0, 1.0, 2.0, 3.0, 4.0]])

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultPaddedValueShapes(self):
    """A scalar `padding_values` broadcasts to every component."""

    def fill(x):
      return array_ops.fill([x], x)

    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]).map(fill),
         dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]).map(fill)))
    dataset = dataset.padded_batch(batch_size=2, padding_values=-1)
    self.assertDatasetProduces(
        dataset,
        expected_output=[([[1, -1], [2, 2]], [[1, -1], [2, 2]]),
                         ([[3, 3, 3, -1], [4, 4, 4, 4]],
                          [[3, 3, 3, -1], [4, 4, 4, 4]])])
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the fused `tf.data.experimental.map_and_batch` transformation."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              num_parallel_calls=[None, 1, 2], num_parallel_batches=None) +
          combinations.combine(
              num_parallel_calls=None, num_parallel_batches=10)))
  def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset ->
    # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      # Element-wise square of every component.
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def dataset_fn(batch_size, count):
      dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
          count).apply(
              batching.map_and_batch(
                  map_func=_map_fn,
                  batch_size=batch_size,
                  num_parallel_calls=num_parallel_calls,
                  num_parallel_batches=num_parallel_batches))
      return dataset

    # Batch of a finite input, where the batch_size divides the
    # total number of elements.
    dataset = dataset_fn(14, 28)
    get_next = self.getNext(dataset)
    # Leading (batch) dimension is unknown (None); trailing dims come from
    # the per-element component shapes.
    self.assertEqual(
        [[None] + list(c.shape[1:]) for c in components],
        [shape.as_list()
         for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    num_batches = (28 * 7) // 14
    for i in range(num_batches):
      result = self.evaluate(get_next())
      for component, result_component in zip(components, result):
        for j in range(14):
          # Input index wraps modulo 7 because of the repeat.
          self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                              result_component[j])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Batch of a finite input, where the batch_size does not
    # divide the total number of elements.
    get_next = self.getNext(dataset_fn(8, 14))

    # We expect (num_batches - 1) full-sized batches.
    num_batches = int(math.ceil((14 * 7) / 8))
    for i in range(num_batches - 1):
      result = self.evaluate(get_next())
      for component, result_component in zip(components, result):
        for j in range(8):
          self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                              result_component[j])
    # The final, partial batch has (14 * 7) % 8 elements.
    result = self.evaluate(get_next())
    for component, result_component in zip(components, result):
      for j in range((14 * 7) % 8):
        self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
                            result_component[j])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Batch of an empty input should fail straight away.
    self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[])

    # Empty batch should be an initialization time error.
    with self.assertRaises(errors.InvalidArgumentError):
      self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testMapAndBatchPartialBatch(self, drop_remainder):
    """Partial final batch is kept or dropped according to `drop_remainder`."""
    dataset = (
        dataset_ops.Dataset.range(10).apply(
            batching.map_and_batch(
                lambda x: array_ops.reshape(x * x, [1]),
                batch_size=4,
                drop_remainder=drop_remainder)))
    # With drop_remainder the batch dimension is statically known (4);
    # otherwise it is unknown (None).
    if drop_remainder:
      self.assertEqual(
          [4, 1], dataset_ops.get_legacy_output_shapes(dataset).as_list())
    else:
      self.assertEqual(
          [None, 1], dataset_ops.get_legacy_output_shapes(dataset).as_list())
    expected_output = [[[0], [1], [4], [9]], [[16], [25], [36], [49]]]
    if not drop_remainder:
      expected_output.append([[64], [81]])
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchYieldsPartialBatch(self):
    """By default (no drop_remainder) the final partial batch is produced."""
    dataset = (
        dataset_ops.Dataset.range(10).apply(
            batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]),
                                   4)))
    self.assertEqual(
        [None, 1], dataset_ops.get_legacy_output_shapes(dataset).as_list())
    expected_output = [[[0], [1], [4], [9]], [[16], [25], [36], [49]],
                       [[64], [81]]]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchParallelGetNext(self):
    """Many outstanding get_next calls still yield each batch exactly once."""
    dataset = dataset_ops.Dataset.range(50000).apply(
        batching.map_and_batch(lambda x: x, batch_size=100))
    if context.executing_eagerly():
      iterator = iter(dataset)
      get_next = iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next
    elements = []
    for _ in range(100):
      elements.append(get_next)
    for i in range(5):
      # Evaluate 100 get_next calls together; order is nondeterministic, so
      # sort by the first element of each batch before comparing.
      got = self.evaluate([element() for element in elements])
      got.sort(key=lambda x: x[0])
      expected = []
      for j in range(100):
        expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
      self.assertAllEqual(got, expected)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate([element() for element in elements])

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchParallelGetNextDropRemainder(self):
    """Parallel get_next with drop_remainder: 49999 elements -> 499 batches."""
    dataset = dataset_ops.Dataset.range(49999).apply(
        batching.map_and_batch(
            lambda x: x, batch_size=100, drop_remainder=True))
    if context.executing_eagerly():
      iterator = iter(dataset)
      get_next = iterator._next_internal  # pylint: disable=protected-access
    else:
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next
    elements = []
    for _ in range(100):
      elements.append(get_next)
    # Only 4 full rounds of 100 batches; the 99-element remainder is dropped.
    for i in range(4):
      got = self.evaluate([element() for element in elements])
      got.sort(key=lambda x: x[0])
      expected = []
      for j in range(100):
        expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
      self.assertAllEqual(got, expected)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate([element() for element in elements])

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchSparse(self):
    """map_and_batch stacks SparseTensor outputs into batched sparse values."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).apply(
        batching.map_and_batch(_sparse, 5))
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            sparse_tensor.SparseTensorValue(
                indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
                values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
                dense_shape=[5, 1]) for i in range(2)
        ])

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchFails(self):
    """Test a dataset that maps a TF function across its input elements."""
    # check_numerics on 1.0/0.0 raises InvalidArgumentError("oops") when the
    # pipeline is evaluated.
    with self.assertRaisesRegex(errors.InvalidArgumentError, "oops"):
      dataset = dataset_ops.Dataset.from_tensors(
          array_ops.check_numerics(
              constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
      dataset = dataset.apply(batching.map_and_batch(lambda x: x, 14))
      get_next = self.getNext(dataset, requires_initialization=True)
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchShapeMismatch(self):
    """Test a dataset that maps a TF function across its input elements."""

    def generator():
      # The final element has a different shape, which cannot be batched
      # with the preceding scalars-in-lists.
      yield [1]
      yield [2]
      yield [3]
      yield [[4, 5, 6]]

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int32)
    batch_size = 4
    dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
    self.assertDatasetProduces(
        dataset,
        expected_error=(errors.InvalidArgumentError,
                        "number of elements does not match"))

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchImplicitDispose(self):
    # Tests whether a map and batch dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
    # MapAndBatchDataset(f=square_3, batch_size=100).
    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
        1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
    dataset = dataset.prefetch(5)
    get_next = self.getNext(dataset)
    # Only consume three batches, then let the iterator be disposed.
    for _ in range(3):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(threshold=[0, 5, 10, 90, 95, 99])))
  def testMapAndBatchMapError(self, threshold):
    """Errors raised by the map function surface as InvalidArgumentError."""

    def raising_py_fn(i):
      # Fail for every element at or beyond the threshold.
      if i >= threshold:
        raise StopIteration()
      else:
        return i

    dataset = dataset_ops.Dataset.range(100).apply(
        batching.map_and_batch(
            lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
            batch_size=10))

    get_next = self.getNext(dataset)
    # Batches made entirely of elements below the threshold succeed.
    for i in range(threshold // 10):
      self.assertAllEqual([i * 10 + j for j in range(10)],
                          self.evaluate(get_next()))
    # All later batches contain a failing element.
    for i in range(threshold // 10, 10):
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(element=False, dtype=dtypes.bool) +
          combinations.combine(
              element=-42,
              dtype=[dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]) +
          combinations.combine(element=42, dtype=[dtypes.uint8, dtypes.uint16])
          + combinations.combine(
              element=42.0,
              dtype=[dtypes.float16, dtypes.float32, dtypes.float64]) +
          combinations.combine(element=b"hello", dtype=[dtypes.string])))
  def testMapAndBatchTypes(self, element, dtype):
    """map_and_batch round-trips elements of every supported dtype."""

    def gen():
      yield element

    dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
        batching.map_and_batch(lambda x: x, batch_size=10))

    get_next = self.getNext(dataset)
    for _ in range(10):
      self.assertAllEqual([element for _ in range(10)],
                          self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testShortCircuitIdentity(self):
    """Identity map function exercises the short-circuit optimization."""
    map_fn = lambda x: x
    dataset = self.structuredDataset(None).repeat().apply(
        batching.map_and_batch(map_fn, batch_size=10))
    get_next = self.getNext(dataset)
    expected = map_fn(self.evaluate(self.structuredElement(None, shape=[10])))
    self.assertAllEqual(expected, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testShortCircuitReplicate(self):
    """Duplicating map function exercises the short-circuit optimization."""
    map_fn = lambda x: (x, x)
    dataset = self.structuredDataset(None).repeat().apply(
        batching.map_and_batch(map_fn, batch_size=10))
    get_next = self.getNext(dataset)
    expected = map_fn(self.evaluate(self.structuredElement(None, shape=[10])))
    self.assertAllEqual(expected, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testShortCircuitSwap(self):
    """Component-swapping map function exercises short-circuiting."""
    map_fn = lambda x, y: (y, x)
    dataset = self.structuredDataset(
        (None,
         None)).repeat().apply(batching.map_and_batch(map_fn, batch_size=10))
    get_next = self.getNext(dataset)
    expected = map_fn(
        *self.evaluate(self.structuredElement((None, None), shape=[10])))
    self.assertAllEqual(expected, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testShortCircuitProject(self):
    """Component-dropping map function exercises short-circuiting."""
    map_fn = lambda x, y: x
    dataset = self.structuredDataset(
        (None,
         None)).repeat().apply(batching.map_and_batch(map_fn, batch_size=10))
    get_next = self.getNext(dataset)
    expected = map_fn(
        *self.evaluate(self.structuredElement((None, None), shape=[10])))
    self.assertAllEqual(expected, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testShortCircuitCapturedInput(self):
    """Map function returning only a captured variable short-circuits."""
    captured_t = variables.Variable(42)
    dataset = self.structuredDataset(None).repeat().apply(
        batching.map_and_batch(lambda x: captured_t, batch_size=10))
    self.evaluate(variables.global_variables_initializer())
    get_next = self.getNext(dataset, requires_initialization=True)
    self.assertAllEqual([42] * 10, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testMapAndBatchControlFlow(self):
    """A map function using control-flow-v2 `cond` works inside map_and_batch."""

    def map_fn(x):
      # Temporarily force control flow v2 while tracing the cond, then
      # restore the previous flag value.
      previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
      return_value = control_flow_ops.cond(x < 50, lambda: x + 1,
                                           lambda: x * x)
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
      return return_value

    dataset = dataset_ops.Dataset.range(100).apply(
        batching.map_and_batch(map_fn, batch_size=10))
    get_next = self.getNext(dataset)
    for i in range(10):
      # Elements < 50 take the x + 1 branch; the rest take x * x.
      if i < 5:
        self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
                            self.evaluate(get_next()))
      else:
        self.assertAllEqual(
            [((i * 10) + j) * ((i * 10) + j) for j in range(10)],
            self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointLargeBatches(self):
    """Checkpointing an iterator whose buffer holds very large batches."""
    # Batches of size 512M
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.ones((64, 1024, 1024), dtype=dtypes.float32)).repeat()
    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=5)
    dataset = dataset.batch(2)
    iterator = iter(dataset)
    next(iterator)  # request an element to fill the buffer
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for distributing datasets via the tf.data service."""

  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      The address of the master.
    """
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    self._servers = []
    for _ in range(num_workers):
      self._servers.append(
          server_lib.WorkerServer(
              port=0, master_address=self._master._address,
              protocol=PROTOCOL))

    return self._master._address

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeBasic(self):
    """A single-worker cluster produces the dataset elements in order."""
    num_elements = 10
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, master_address)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentShuffleOrders(self):
    """Two workers processing a shuffled dataset use different orders."""
    random_seed.set_random_seed(None)
    num_elements = 100
    master_address = self.create_cluster(2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = _make_distributed_dataset(ds, master_address)
    output = [elem.numpy() for elem in ds]

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    """Re-iterating the same distributed dataset replays all elements."""
    num_elements = 3
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, master_address)
    for _ in range(10):
      self.assertEqual(
          list(range(num_elements)), [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatedDataset(self):
    """repeat() applied after distribution repeats the distributed output."""
    num_elements = 10
    num_repetitions = 5
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, master_address)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    """Several independent distributed datasets can be read interleaved."""
    num_elements = 10
    num_datasets = 3
    master_address = self.create_cluster(1)
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = dataset_ops.Dataset.range(num_elements)
      ds = _make_distributed_dataset(ds, master_address)
      iterators.append(iter(ds))
      results.append([])

    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = next(iterators[dataset_ind]).numpy()
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    """Multiple iterators sharing one epoch split the elements between them."""
    self.skipTest("Not yet implemented")
    num_elements = 10
    num_iterators = 3
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, master_address)
    result = []
    iterators = []
    for _ in range(num_iterators):
      iterators.append(iter(ds))

    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
      for elem in it:
        result.append(elem.numpy())

    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    """Every worker processes the full dataset in parallel_epochs mode."""
    num_workers = 3
    num_elements = 10
    master_address = self.create_cluster(num_workers)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, master_address)
    results = [elem.numpy() for elem in ds]
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testAddWorkerMidJob(self):
    """A worker added mid-job also processes the dataset."""
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    self._worker = server_lib.WorkerServer(
        port=0, master_address=self._master._address, protocol=PROTOCOL)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, self._master._address)
    iterator = iter(ds)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())

    self._new_worker = server_lib.WorkerServer(
        port=0, master_address=self._master._address, protocol=PROTOCOL)

    # Wait for the new worker to register with the master.
    while self._master._num_workers() < 2:
      time.sleep(10 / 1000)  # 10ms

    for elem in iterator:
      results.append(elem.numpy())

    self.assertCountEqual(2 * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(use_same_port=[True, False])))
  def testRestartWorker(self, use_same_port):
    """After a worker restarts, the dataset is reprocessed from the start."""
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    self._worker = server_lib.WorkerServer(
        port=0, master_address=self._master._address, protocol=PROTOCOL)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, self._master._address)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())

    # Stop the original worker and start a new one.
    port = 0
    if use_same_port:
      port = int(self._worker._address.split(":")[1])
    self._worker._stop()
    self._new_worker = server_lib.WorkerServer(
        port=port, master_address=self._master._address, protocol=PROTOCOL)

    # There may have been some elements prefetched from the first worker
    # before it was stopped.
    while True:
      val = next(iterator).numpy()
      if val == 0:
        break

    # The dataset starts over now that we read from the new worker.
    # TODO(b/157086991): Iterate until end of sequence when we support
    # detecting lost workers.
    for i in range(1, num_elements // 2):
      val = next(iterator).numpy()
      self.assertEqual(i, val)

  @combinations.generate(test_base.eager_only_combinations())
  def testMaxOutstandingRequests(self):
    """The dataset still produces all elements with only one request in flight."""
    num_elements = 10
    num_workers = 3
    address = self.create_cluster(num_workers)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(
        data_service_ops._distribute(
            "parallel_epochs",
            "{0}://{1}".format(PROTOCOL, address),
            max_outstanding_requests=1,
            task_refresh_interval_hint_ms=20))
    self.assertCountEqual(num_workers * list(range(num_elements)),
                          self.getDatasetOutput(ds))

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    """A distributed dataset can be consumed inside a tf.function."""
    num_workers = 3
    num_elements = 10
    master_address = self.create_cluster(num_workers)

    @def_function.function
    def f():
      ds = dataset_ops.Dataset.range(num_elements)
      ds = _make_distributed_dataset(ds, master_address)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    """Two datasets with the same job name share a single stream of elements."""
    num_elements = 100
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    # Alternate between the two consumers, then drain both.
    for _ in range(num_elements // 5):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    """Datasets with different job names get independent streams."""
    num_elements = 10
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    """With a shared job, whichever dataset reads first consumes the epoch."""
    num_elements = 10
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    """Shared-job consumers together see each element num_repetitions times."""
    num_elements = 100
    num_repetitions = 3
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter1).numpy())
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testApplyDeterminismOption(self):
    """experimental_deterministic=False is honored by the distributed dataset."""
    elements = list(range(10))
    master_address = self.create_cluster(1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        # Slow down only element 0 so nondeterministic interleaves reorder.
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = _make_distributed_dataset(ds, master_address)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a stateful (random) dataset under the given policy."""
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))

    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)

    master_address = self.create_cluster(3)
    ds = _make_distributed_dataset(ds, master_address)
    next(iter(ds))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    """IGNORE and WARN policies allow distributing stateful datasets."""
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    """FAIL policy rejects stateful datasets with FailedPreconditionError."""
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeFromInterleave(self):
    """distribute() inside a tf.data function is rejected."""
    master_address = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      ds = dataset_ops.Dataset.range(2)
      _make_distributed_dataset(ds, master_address)
      return ds

    with self.assertRaisesRegex(
        errors.InvalidArgumentError, r"The `.distribute\(...\)` dataset "
        "transformation is not supported within tf.data functions"):
      ds = ds.interleave(interleave_fn, cycle_length=2)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    """A non-string `service` argument raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeEmptyAddress(self):
    """An empty `service` address raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    """An unrecognized processing mode raises ValueError."""
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))
class LocalTaskGarbageCollectTest(data_service_test_base.TestBase,
                                  parameterized.TestCase):
  """Tests garbage collecting unused local worker tasks.

  The user typically creates an iterator in each epoch. This should delete the
  previous iterator and releases the resources of it.
  """

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochs(self, num_remote_workers):
    """Re-creating the iterator each epoch still reads from step 0."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_epochs, num_steps = 5, 5
    dataset = self._make_distributed_infinite_range_dataset(cluster)
    for _ in range(num_epochs):
      # For each iteration, the previous iterator is garbage collected.
      get_next = self.getNext(dataset)
      for i in range(num_steps):
        self.assertEqual(self.evaluate(get_next()), i)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochsSharedJob(self, num_remote_workers):
    """Same as testMultipleEpochs, but with a shared job name."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_epochs, num_steps = 5, 5
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    for _ in range(num_epochs):
      # For each iteration, the previous iterator is garbage collected.
      get_next = self.getNext(dataset)
      for i in range(num_steps):
        self.assertEqual(self.evaluate(get_next()), i)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              num_remote_workers=[0, 3], job_name=[None, "shared_job_name"])))
  def testRepeatDistributedDataset(self, num_remote_workers, job_name):
    """repeat() over a local-worker distributed dataset replays each epoch."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    dataset = self.make_distributed_range_dataset(
        10, cluster, job_name=job_name, target_workers="LOCAL")
    dataset = dataset.repeat(3)
    self.assertDatasetProduces(dataset, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testReadFromDeletedTask(self, num_remote_workers):
    """Reading from a garbage-collected task fails with a clear error."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

    # Re-creating the dataset resets the iterator index, so the second iterator
    # reads from the same task as the first, which has been deleted.
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "which has been deleted."):
      _ = self.evaluate(get_next())

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testReadFromDeletedTask_GraphMode(self, num_remote_workers):
    """Graph-mode variant of testReadFromDeletedTask."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    with self.session() as sess:
      get_next = self.getNext(dataset)
      for i in range(num_steps):
        self.assertEqual(sess.run(get_next()), i)

    # Re-creating the dataset resets the iterator index, so the second iterator
    # reads from the same task as the first, which has been deleted.
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "which has been deleted."):
      with self.session() as sess:
        get_next = self.getNext(dataset)
        sess.run(get_next())

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochs_WorkerRestart(self, num_remote_workers):
    """A restarted worker re-creates the task for a fresh iterator."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

    # Verifies the worker re-creates the task after the iterator is deleted and
    # the worker restarts.
    del get_next
    cluster.restart_local_workers()
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochs_DispatcherRestart(self, num_remote_workers):
    """A restarted dispatcher re-creates the task for a fresh iterator."""
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

    # Verifies the worker re-creates the task after the iterator is deleted and
    # the dispatcher restarts.
    del get_next
    cluster.restart_dispatcher()
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

  def _make_distributed_infinite_range_dataset(self, cluster, job_name=None):
    """Distributes an effectively-infinite range dataset to local workers."""
    dataset = dataset_ops.Dataset.range(1000000).repeat()
    return self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        job_name=job_name,
        processing_mode=ShardingPolicy.OFF,
        target_workers="LOCAL")
class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.Dataset.from_generator`."""

  def _testFromGenerator(self, generator, elem_sequence, num_repeats,
                         requires_initialization):
    """Asserts the generator dataset produces `elem_sequence` x `num_repeats`.

    Args:
      generator: a zero-argument callable returning an iterable of int64
        elements.
      elem_sequence: the expected elements of one pass over `generator()`.
      num_repeats: how many times the dataset repeats the generator.
      requires_initialization: forwarded to `assertDatasetProduces`.
    """
    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
    self.assertDatasetProduces(
        dataset,
        elem_sequence * num_repeats,
        requires_initialization=requires_initialization,
        num_test_iterations=2)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingFn(self, num_repeats, requires_initialization):
    """Generator defined as a named generator function."""

    def generator():
      for i in range(1, 100):
        yield [i] * i

    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingList(self, num_repeats, requires_initialization):
    """Generator callable that returns a plain list."""
    generator = lambda: [[i] * i for i in range(1, 100)]
    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingNdarray(self, num_repeats,
                                    requires_initialization):
    """Generator callable that returns an ndarray (iterated per-row)."""
    generator = lambda: np.arange(100, dtype=np.int64)
    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_repeats=[1, 5],
                               requires_initialization=[True, False])))
  def testFromGeneratorUsingGeneratorExpression(self, num_repeats,
                                                requires_initialization):
    """Generator expression wrapped in a lambda so it can be re-created."""
    # NOTE(mrry): Generator *expressions* are not repeatable (or in general
    # reusable), because they eagerly evaluate the `for` expression as
    # `iter(range(1, 100))` and discard the means of reconstructing
    # `range(1, 100)`. Wrapping the generator expression in a `lambda` makes
    # it repeatable.
    generator = lambda: ([i] * i for i in range(1, 100))
    elem_sequence = list(generator())
    self._testFromGenerator(
        generator,
        elem_sequence,
        num_repeats=num_repeats,
        requires_initialization=requires_initialization)

  @combinations.generate(test_base.default_test_combinations())
  def testFromMultipleConcurrentGenerators(self):
    """Multiple generator iterators may be active at the same time."""
    num_inner_repeats = 5
    num_outer_repeats = 100

    def generator():
      for i in range(1, 10):
        yield ([i] * i, [i, i**2, i**3])

    input_list = list(generator())

    # The interleave transformation is essentially a flat map that
    # draws from multiple input datasets concurrently (in a cyclic
    # fashion). By placing `Dataset.from_generator()` inside an
    # interleave, we test its behavior when multiple iterators are
    # active at the same time; by additionally prefetching inside the
    # interleave, we create the possibility of parallel (modulo GIL)
    # invocations to several iterators created by the same dataset.
    def interleave_fn(_):
      return (dataset_ops.Dataset.from_generator(
          generator,
          output_types=(dtypes.int64, dtypes.int64),
          output_shapes=([None], [3])).repeat(num_inner_repeats).prefetch(5))

    dataset = dataset_ops.Dataset.range(num_outer_repeats).interleave(
        interleave_fn, cycle_length=10, block_length=len(input_list))
    get_next = self.getNext(dataset)
    for _ in range(num_inner_repeats * num_outer_repeats):
      for elem in input_list:
        val0, val1 = self.evaluate(get_next())
        self.assertAllEqual(elem[0], val0)
        self.assertAllEqual(elem[1], val1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  # TODO(b/67868766): Reenable this when the source of flakiness is
  # discovered.
  def _testFromGeneratorsRunningInParallel(self):
    """Barrier test proving several generator iterators run concurrently.

    Disabled (leading underscore) pending the flakiness investigation above.
    """
    num_parallel_iterators = 3

    # Define shared state that multiple iterator instances will access to
    # demonstrate their concurrent activity.
    lock = threading.Lock()
    condition = threading.Condition(lock)
    next_ticket = [0]  # GUARDED_BY(lock)

    def generator():
      # NOTE(mrry): We yield one element before the barrier, because
      # the current implementation of `Dataset.interleave()` must
      # fetch one element from each incoming dataset to start the
      # prefetching.
      yield 0

      # Define a barrier that `num_parallel_iterators` iterators must enter
      # before any can proceed. Demonstrates that multiple iterators may be
      # active at the same time.
      condition.acquire()
      ticket = next_ticket[0]
      next_ticket[0] += 1
      if ticket == num_parallel_iterators - 1:
        # The last iterator to join the barrier notifies the others.
        condition.notify_all()
      else:
        # Wait until the last iterator enters the barrier.
        while next_ticket[0] < num_parallel_iterators:
          condition.wait()
      condition.release()

      yield 1

    # As in `testFromMultipleConcurrentGenerators()`, we use a combination of
    # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple
    # iterators to be active concurrently.
    def interleave_fn(_):
      return dataset_ops.Dataset.from_generator(
          generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)

    dataset = dataset_ops.Dataset.range(num_parallel_iterators).interleave(
        interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
    get_next = self.getNext(dataset)

    for elem in [0, 1]:
      for _ in range(num_parallel_iterators):
        self.assertAllEqual(elem, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorImplicitConversion(self):
    """Python ints are converted to each requested integer dtype."""

    def generator():
      yield [1]
      yield [2]
      yield [3]

    for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]:
      dataset = dataset_ops.Dataset.from_generator(generator,
                                                   output_types=dtype,
                                                   output_shapes=[1])
      get_next = self.getNext(dataset)

      for expected in [[1], [2], [3]]:
        next_val = self.evaluate(get_next())
        self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
        self.assertAllEqual(expected, next_val)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorString(self):
    """str, bytes and unicode yields all surface as bytes elements."""

    def generator():
      yield "foo"
      yield b"bar"
      yield u"baz"

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.string, output_shapes=[])
    self.assertDatasetProduces(
        dataset, expected_output=[b"foo", b"bar", b"baz"])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorDict(self):
    """Dict-structured elements round-trip with per-key types and shapes."""

    def generator():
      yield {"a": "foo", "b": [1, 2]}
      yield {"a": "bar", "b": [3, 4]}
      yield {"a": "baz", "b": [5, 6]}

    dataset = dataset_ops.Dataset.from_generator(generator,
                                                 output_types={
                                                     "a": dtypes.string,
                                                     "b": dtypes.int32
                                                 },
                                                 output_shapes={
                                                     "a": [],
                                                     "b": [None]
                                                 })
    self.assertDatasetProduces(dataset,
                               expected_output=[{
                                   "a": b"foo",
                                   "b": [1, 2]
                               }, {
                                   "a": b"bar",
                                   "b": [3, 4]
                               }, {
                                   "a": b"baz",
                                   "b": [5, 6]
                               }])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorTypeError(self):
    """A wrongly-typed yield raises, and iteration can continue past it."""

    def generator():
      yield np.array([1, 2, 3], dtype=np.int64)
      yield np.array([4, 5, 6], dtype=np.int64)
      yield "ERROR"
      yield np.array([7, 8, 9], dtype=np.int64)

    dataset = dataset_ops.Dataset.from_generator(generator,
                                                 output_types=dtypes.int64,
                                                 output_shapes=[3])
    get_next = self.getNext(dataset)

    self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
    self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
    # The bad element raises, but does not terminate the generator.
    with self.assertRaisesOpError("The expected type was int64"):
      self.evaluate(get_next())
    self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorShapeError(self):
    """A wrongly-shaped yield raises, and iteration can continue past it."""

    def generator():
      yield np.array([1, 2, 3], dtype=np.int64)
      yield np.array([4, 5, 6], dtype=np.int64)
      yield np.array([7, 8, 9, 10], dtype=np.int64)
      yield np.array([11, 12, 13], dtype=np.int64)

    dataset = dataset_ops.Dataset.from_generator(generator,
                                                 output_types=dtypes.int64,
                                                 output_shapes=[3])
    get_next = self.getNext(dataset)

    self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
    self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
    with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
      self.evaluate(get_next())
    self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorStructureError(self):
    """Structure mismatches (too few / too many components) raise per-element."""

    def generator():
      yield 1, 2
      yield 3, 4
      yield 5          # too few components -> error
      yield 6, 7, 8    # too many components -> error
      yield 9, 10

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=(dtypes.int64, dtypes.int64))
    get_next = self.getNext(dataset)

    self.assertEqual((1, 2), self.evaluate(get_next()))
    self.assertEqual((3, 4), self.evaluate(get_next()))
    with self.assertRaisesOpError(
        r"The expected structure was \(tf\.int64, tf\.int64\)"):
      self.evaluate(get_next())
    with self.assertRaisesOpError(
        r"The expected structure was \(tf\.int64, tf\.int64\)"):
      self.evaluate(get_next())
    self.assertEqual((9, 10), self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorHeterogeneous(self):
    """Elements of differing shapes are allowed when no shape is declared."""

    def generator():
      yield 1
      yield [2, 3]

    dataset = dataset_ops.Dataset.from_generator(generator,
                                                 output_types=dtypes.int64)
    self.assertDatasetProduces(dataset, expected_output=[1, [2, 3]])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorStopShort(self):
    """Abandoning iteration before exhaustion must not hang or error."""

    def generator():
      yield 0
      yield 1
      yield 2

    dataset = dataset_ops.Dataset.from_generator(generator,
                                                 output_types=dtypes.int64)
    get_next = self.getNext(dataset)
    self.assertAllEqual(0, self.evaluate(get_next()))
    self.assertAllEqual(1, self.evaluate(get_next()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorDestructorCalled(self):
    """The Python generator object is destroyed when its iterator ends."""
    # Use an `Event` to signal that the generator has been deleted.
    event = threading.Event()

    class GeneratorWrapper(object):
      """Infinite iterator that sets `event` from its destructor."""

      def __iter__(self):
        return self

      def next(self):  # Python 2 compatibility alias.
        return self.__next__()

      def __next__(self):
        return 42

      def __del__(self):
        event.set()

    dataset = dataset_ops.Dataset.from_generator(
        GeneratorWrapper, output_types=dtypes.int64).take(2)
    get_next = self.getNext(dataset)

    self.assertAllEqual(42, self.evaluate(get_next()))
    self.assertAllEqual(42, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    # Test that `GeneratorWrapper` object is destroyed when the
    # iterator terminates (and the generator iterator is deleted).
    self.assertTrue(event.is_set())

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorWithArgs(self):
    """`args=` forwards a tensor argument to the generator callable."""

    def flat_map_fn(elem):

      def generator_with_arg(n):
        for _ in range(n):
          yield np.array(n, dtype=np.int64)

      return dataset_ops.Dataset.from_generator(
          generator_with_arg,
          output_types=dtypes.int64,
          output_shapes=(),
          args=(elem,))

    # n=0 yields nothing, so the output starts at 1.
    dataset = dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
    self.assertDatasetProduces(
        dataset, expected_output=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4])

  @combinations.generate(test_base.default_test_combinations())
  def testFromGeneratorWithTwoArgs(self):
    """`args=` forwards multiple tensor arguments of mixed dtypes."""

    def flat_map_fn(elem, message):

      def generator_with_arg(n, msg):
        for i in range(n):
          yield i, msg

      return dataset_ops.Dataset.from_generator(
          generator_with_arg,
          output_types=(dtypes.int64, dtypes.string),
          output_shapes=((), ()),
          args=(elem, message))

    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(5),
         dataset_ops.Dataset.from_tensors("Hi!").repeat(None)
        )).flat_map(flat_map_fn)

    self.assertDatasetProduces(dataset,
                               expected_output=[(0, b"Hi!"), (0, b"Hi!"),
                                                (1, b"Hi!"), (0, b"Hi!"),
                                                (1, b"Hi!"), (2, b"Hi!"),
                                                (0, b"Hi!"), (1, b"Hi!"),
                                                (2, b"Hi!"), (3, b"Hi!")])

  @combinations.generate(test_base.default_test_combinations())
  def testGeneratorDatasetFinalizeFunctionCalled(self):
    """The finalize function runs when the generator iterator is deleted."""
    # NOTE(mrry): This test tests the internal `_GeneratorDataset`,
    # which affords more control over what the finalize function can do than
    # the `Dataset.from_generator()` wrapper.

    # Use an `Event` to signal that the generator has been deleted.
    event = threading.Event()

    def finalize_fn(_):

      def finalize_py_func():
        event.set()
        return 0

      return script_ops.py_func(finalize_py_func, [], [dtypes.int64],
                                stateful=True)

    dummy = constant_op.constant(37)
    dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
                                            finalize_fn).take(2)
    get_next = self.getNext(dataset)

    self.assertAllEqual(37, self.evaluate(get_next()))
    self.assertAllEqual(37, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testSharedName(self):
    """A generator dataset iterator works with a shared name."""

    def generator():
      for _ in range(10):
        yield [20]

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=(dtypes.int64))
    get_next = self.getNext(
        dataset, requires_initialization=True, shared_name="shared_dataset")

    self.assertAllEqual([20], self.evaluate(get_next()))
class AutoShardTest(data_service_test_base.TestBase,
                    tf_record_test_base.TFRecordTestBase,
                    parameterized.TestCase):
  """Tests auto-sharding datasets with tf.data service."""

  def setUp(self):
    super(AutoShardTest, self).setUp()
    # 10 TFRecord files x 10 records each; created by the TFRecord test base.
    self._num_files = 10
    self._num_records = 10
    self._filenames = self._createFiles()

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              ShardingPolicy.DATA, ShardingPolicy.FILE_OR_DATA
          ])))
  def testRangeDataset_AutoShard(self, sharding_policy):
    """Data sharding of range(20) across 5 workers yields shard index 1."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=sharding_policy)
    # Local worker holds shard 1 of 5: elements congruent to 1 mod 5.
    self.assertDatasetProduces(dataset, [1, 6, 11, 16])

  @combinations.generate(test_base.default_test_combinations())
  def testRangeDataset_FileShard(self):
    """FILE sharding fails on a range dataset (no file-based source)."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE)
    with self.assertRaisesRegex(errors.NotFoundError,
                                "Found an unshardable source dataset"):
      self.getDatasetOutput(dataset)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(worker_index=[distribute.SHARD_HINT, 0, 5])))
  def testRangeDataset_ShardHint(self, worker_index):
    """HINT sharding substitutes the worker's shard for SHARD_HINT."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    # With HINT sharding, `num_shards` should be `SHARD_HINT`; `index` can be
    # any value.
    dataset = dataset.shard(num_shards=distribute.SHARD_HINT,
                            index=worker_index)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
    self.assertDatasetProduces(dataset, [1, 6, 11, 16])

  @combinations.generate(test_base.default_test_combinations())
  def testRangeDataset_InvalidWorkerIndexUsingShardHint(self):
    """Passing SHARD_HINT as `index` (not `num_shards`) is rejected."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    # With HINT sharding, `SHARD_HINT` should be passed to `num_shards`, not
    # `index`.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Index must be between 0 and 4 \(currently index = -1\)."):
      dataset = dataset.shard(num_shards=5, index=distribute.SHARD_HINT)
      dataset = self.make_distributed_dataset(
          dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testRangeDataset_NoShardHint(self):
    """Without SHARD_HINT, HINT mode leaves the user's shard() untouched."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    # No SHARD_HINT is provided. The given sharding arguments will be used.
    dataset = dataset.shard(num_shards=1, index=0)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
    self.assertDatasetProduces(dataset, list(range(20)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              ShardingPolicy.OFF, ShardingPolicy.FILE_OR_DATA
          ])))
  def testRangeDataset_ShardHintUsedInWrongShardingPolicy(
      self, sharding_policy):
    """Using SHARD_HINT outside HINT mode raises FailedPrecondition."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=sharding_policy)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "tf.data service with "
        "`tf.data.experimental.service.ShardingPolicy.HINT` processing mode."
    ):
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testRangeDataset_NoShard(self):
    """With sharding OFF, the local worker produces the whole dataset."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        processing_mode=ShardingPolicy.OFF,
        target_workers="LOCAL")
    self.assertDatasetProduces(dataset, list(range(20)))

  @combinations.generate(test_base.default_test_combinations())
  def testRangeDataset_OneWorker(self):
    """Makes sure shards from all workers form the complete dataset."""
    cluster = _make_service_cluster(num_workers=1, local_shard_index=0)
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    self.assertDatasetProduces(dataset, list(range(20)))

  @combinations.generate(test_base.default_test_combinations())
  def testRangeDataset_ReadFromAllWorkers(self):
    """Static sharding with target "ANY" is rejected."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        processing_mode=ShardingPolicy.FILE_OR_DATA,
        target_workers="ANY")
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Static sharding requires reading from local workers"):
      self.getDatasetOutput(dataset)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE
          ])))
  def testTFRecordDataset_AutoShard(self, sharding_policy):
    """File sharding assigns whole files round-robin (worker 3 -> files 3, 8)."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        processing_mode=sharding_policy,
        target_workers="LOCAL")
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in (3, 8)
        for record in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE
          ])))
  def testTFRecordDataset_ShuffleFileList(self, sharding_policy):
    """File sharding with a shuffled file list still yields the same files."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=True)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=sharding_policy)
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in (3, 8)
        for record in range(0, 10)
    ]
    # Order is not deterministic under shuffle, so compare as multisets.
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_DataShard(self):
    """DATA sharding takes every 5th record from every file."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.DATA)
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in range(0, 10)
        for record in (3, 8)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_HintDataShard(self):
    """HINT sharding after flat_map behaves like data sharding."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in range(0, 10)
        for record in (3, 8)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_HintFileShard(self):
    """HINT sharding before flat_map behaves like file sharding."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in (3, 8)
        for record in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_NoShard(self):
    """With sharding OFF, the local worker reads every record of every file."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        processing_mode=ShardingPolicy.OFF,
        target_workers="LOCAL")
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in range(0, 10)
        for record in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_ReadFromAllWorkers(self):
    """Makes sure shards from all workers form the complete dataset."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        processing_mode=ShardingPolicy.FILE_OR_DATA,
        target_workers="ANY")
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Static sharding requires reading from local workers"):
      self.getDatasetOutput(dataset)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE
          ])))
  def testTFRecordDataset_FewerFilesThanWorkers(self, sharding_policy):
    """File sharding with 4 files for 5 workers raises a clear error."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames[:4],
                                             shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=sharding_policy)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "not enough for the required 5 shards/workers."):
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_FewerFilesThanWorkers_HintShard(self):
    """HINT file sharding with too few files raises the same error."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames[:4],
                                             shuffle=False)
    dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "not enough for the required 5 shards/workers."):
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordDataset_FewerFilesThanWorkers_DataShard(self):
    """DATA sharding still works when there are fewer files than workers."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames[:4],
                                             shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.DATA)
    expected = [
        b"Record %d of file %d" % (record, file)
        for file in range(0, 4)
        for record in (3, 8)
    ]
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              ShardingPolicy.FILE_OR_DATA, ShardingPolicy.DATA
          ])))
  def testBatchDataset(self, sharding_policy):
    """Sharding applies to whole batches produced downstream of batch()."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
    dataset = dataset_ops.Dataset.range(20)
    dataset = dataset.batch(batch_size=3, drop_remainder=False)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=sharding_policy)
    # Batches at positions 1 and 6 of the 7 batches (the last is a remainder).
    self.assertDatasetProduces(dataset, [[3, 4, 5], [18, 19]])

  @combinations.generate(test_base.default_test_combinations())
  def testInterleaveDataset(self):
    """Auto-sharding rewrites a parallel-interleave file pipeline."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.interleave(readers.TFRecordDataset,
                                 cycle_length=10,
                                 num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
    # Interleave alternates between the two local files record-by-record.
    expected = [
        b"Record %d of file %d" % (record, file)
        for record in range(0, 10)
        for file in (3, 8)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDataset(self):
    """Auto-sharding shards both branches of a zip consistently."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset1 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset1 = dataset1.interleave(readers.TFRecordDataset,
                                   cycle_length=10,
                                   num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset2 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset2 = dataset2.interleave(readers.TFRecordDataset,
                                   cycle_length=10,
                                   num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    expected = [(b"Record %d of file %d" % (record, file),
                 b"Record %d of file %d" % (record, file))
                for record in range(0, 10)
                for file in (3, 8)]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDataset(self):
    """Auto-sharding shards both halves of a concatenation."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset1 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset1 = dataset1.interleave(readers.TFRecordDataset,
                                   cycle_length=10,
                                   num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset2 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset2 = dataset2.interleave(readers.TFRecordDataset,
                                   cycle_length=10,
                                   num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset = dataset1.concatenate(dataset2)
    dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    expected = [
        b"Record %d of file %d" % (record, file)
        for record in range(0, 10)
        for file in (3, 8)
    ]
    expected += expected  # Each half contributes the same shard.
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDataset(self):
    """Sharding an empty dataset produces no elements and no errors."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.range(0)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testAnonymousPorts(self):
    """`%port%` placeholders in worker addresses are resolved at runtime."""
    cluster = _make_service_cluster(
        num_workers=5,
        local_shard_index=3,
        worker_addresses=["localhost:%port%" for _ in range(5)])
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    self.assertDatasetProduces(dataset, [3, 8, 13, 18])

  @combinations.generate(test_base.default_test_combinations())
  def testNamedPorts(self):
    """Named port placeholders (`%port_worker%`) are also resolved."""
    cluster = _make_service_cluster(
        num_workers=5,
        local_shard_index=3,
        worker_addresses=["localhost:%port_worker%" for _ in range(5)])
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    self.assertDatasetProduces(dataset, [3, 8, 13, 18])

  @combinations.generate(test_base.default_test_combinations())
  def testInvalidPorts(self):
    """A non-numeric, non-placeholder port fails at cluster construction."""
    with self.assertRaisesRegex(RuntimeError,
                                "The worker's address is not configured"):
      _ = _make_service_cluster(
          num_workers=5,
          local_shard_index=0,
          worker_addresses=["localhost:worker" for _ in range(5)])

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyWorkerList(self):
    """Workers absent from the configured address list are rejected."""
    cluster = _make_service_cluster(num_workers=5,
                                    local_shard_index=1,
                                    worker_addresses=[])
    dataset = dataset_ops.Dataset.range(20)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    with self.assertRaisesRegex(errors.NotFoundError,
                                "Worker .* is not in the workers list."):
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testWorkerNotFound(self):
    """Addresses that match no real worker fail at cluster construction."""
    worker_addresses = [f"fake_worker_{i}" for i in range(5)]
    with self.assertRaisesRegex(RuntimeError,
                                "The worker's address is not configured"):
      _ = _make_service_cluster(num_workers=5,
                                local_shard_index=0,
                                worker_addresses=worker_addresses)

  @combinations.generate(test_base.default_test_combinations())
  def testMoreWorkersThanConfigured(self):
    """Starting more workers than configured addresses is an error."""
    worker_addresses = ["localhost:%port%"]
    with self.assertRaisesRegex(
        RuntimeError,
        "other workers are already running at the configured host"):
      _ = _make_service_cluster(num_workers=5,
                                local_shard_index=1,
                                worker_addresses=worker_addresses)

  @combinations.generate(test_base.default_test_combinations())
  def testNoLocalWorkers(self):
    """Static sharding requires at least one local worker."""
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=0, num_remote_workers=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Local reads or static sharding require local tf.data workers"):
      self.getDatasetOutput(dataset)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=list(ShardingPolicy))))
  def testEnumerateShardingPolicies(self, sharding_policy):
    """Verifies tf.data service handles every sharding policy with no errors."""
    cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = self.make_distributed_dataset(
        dataset, cluster=cluster, processing_mode=sharding_policy)
    self.getDatasetOutput(dataset)
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.Dataset.shuffle()` semantics and seeding."""

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    """Shuffling permutes elements; same seed reproduces the permutation."""
    components = (
        np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0])
    )

    def dataset_fn(count=5, buffer_size=None, seed=0):
      # With no buffer_size, return the unshuffled pipeline as ground truth.
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if buffer_size:
        shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)
        # Shuffling must not change the element structure/shapes.
        self.assertEqual(
            tuple([c.shape[1:] for c in components]),
            dataset_ops.get_legacy_output_shapes(shuffle_dataset))
        return shuffle_dataset
      else:
        return repeat_dataset

    # First run without shuffling to collect the "ground truth".
    get_next = self.getNext(dataset_fn())
    unshuffled_elements = []
    for _ in range(20):
      unshuffled_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    shuffled_elements = []
    for _ in range(20):
      shuffled_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    # NOTE(review): this second identical check appears deliberate — it
    # verifies the exhausted iterator keeps raising OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements))

    # Assert that shuffling twice with the same seeds gives the same sequence.
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    reshuffled_elements_same_seed = []
    for _ in range(20):
      reshuffled_elements_same_seed.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)

    # Assert that shuffling twice with a different seed gives a different
    # permutation of the same elements.
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=137))
    reshuffled_elements_different_seed = []
    for _ in range(20):
      reshuffled_elements_different_seed.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
    self.assertAllEqual(
        sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth" when the buffer size is smaller than the input
    # dataset.
    get_next = self.getNext(dataset_fn(buffer_size=2, seed=37))
    reshuffled_elements_small_buffer = []
    for _ in range(20):
      reshuffled_elements_small_buffer.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(
        sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))

    # Test the case of shuffling an empty dataset.
    get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
  def testSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=0))
    get_next = iterator.get_next()

    elems = []
    with self.cached_session() as sess:
      for _ in range(10):
        elems.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

    # Same pipeline, but the seed is fed as a scalar int64 tensor.
    seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder))
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
      for elem in elems:
        self.assertEqual(elem, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultArguments(self):
    """With defaults, repeated shuffling still yields a uniform multiset."""
    components = [0, 1, 2, 3, 4]
    dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
        5).repeat()
    get_next = self.getNext(dataset)
    counts = collections.defaultdict(lambda: 0)
    for _ in range(10):
      for _ in range(5):
        counts[self.evaluate(get_next())] += 1

    # 10 passes over 5 elements: each value must appear exactly 10 times.
    for i in range(5):
      self.assertEqual(10, counts[i])

  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(reshuffle=[True, False]),
          combinations.combine(graph_seed=38, op_seed=None) +
          combinations.combine(graph_seed=None, op_seed=42) +
          combinations.combine(graph_seed=38, op_seed=42)))
  def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
    """Any combination of graph-level and op-level seed is reproducible."""
    results = []
    for _ in range(2):
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(graph_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        run_results = []
        with self.session(graph=g) as sess:
          for _ in range(30):
            run_results.append(sess.run(next_element))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
        results.append(run_results)

    # Two fresh graphs with identical seeds must produce identical streams.
    self.assertAllEqual(results[0], results[1])

  # TODO(b/117581999): enable this test for eager-mode.
  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(
              reshuffle=[True, False], initializable=[True, False])))
  def testMultipleIterators(self, reshuffle, initializable):
    """Two unseeded iterators over one dataset see different orders."""
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      if initializable:
        iterators = [dataset_ops.make_initializable_iterator(dataset)
                     for _ in range(2)]
      else:
        iterators = [dataset_ops.make_one_shot_iterator(dataset)
                     for _ in range(2)]

      results = []
      with self.session(graph=g) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_element = iterator.get_next()
          run_results = []
          for _ in range(300):
            run_results.append(sess.run(next_element))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
          results.append(run_results)

      self.assertNotEqual(results[0], results[1])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleRepeatEpochs(self, reshuffle, seed):
    """`reshuffle_each_iteration` controls whether repeat() epochs differ."""
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
    next_element = self.getNext(dataset)

    first_epoch = []
    for _ in range(10):
      first_epoch.append(self.evaluate(next_element()))

    second_epoch = []
    for _ in range(10):
      second_epoch.append(self.evaluate(next_element()))

    # Epochs are equal exactly when reshuffling is disabled.
    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=2, mode="eager"),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleIterationEpochs(self, reshuffle, seed):
    """Same property as above, but across separate eager iterations."""
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle)

    first_epoch = []
    for elem in dataset:
      first_epoch.append(elem.numpy())

    second_epoch = []
    for elem in dataset:
      second_epoch.append(elem.numpy())

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2ResourceCapture(self):
    """A captured shuffled dataset (seed resource) works inside interleave."""

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      ids = ids.shuffle(1)

      def interleave_fn(dataset, _):
        return dataset

      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = []
    for elem in make_dataset():
      results.append(elem.numpy())

    # buffer_size=1 makes the shuffle a no-op, so order is preserved.
    self.assertAllEqual(results, range(10))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleSeparateTransformations(self, reshuffle, seed):
    """Two separately-built shuffles differ only when unseeded."""
    dataset = dataset_ops.Dataset.range(10)

    first_epoch = []
    for elem in dataset.shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle):
      first_epoch.append(elem.numpy())

    second_epoch = []
    for elem in dataset.shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle):
      second_epoch.append(elem.numpy())

    self.assertEqual(first_epoch != second_epoch, seed is None)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2InFunction(self):
    """Shuffle's seed resource is usable from inside a tf.function."""
    counter_var = variables.Variable(0)

    @function.defun
    def consume():
      ds = dataset_ops.Dataset.range(10)
      ds = ds.shuffle(1)
      for _ in ds:
        counter_var.assign(counter_var + 1)

    consume()
    self.assertAllEqual(self.evaluate(counter_var), 10)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDataset(self):
    """Shuffling a cached dataset whose only element errored yields EOF."""
    dataset = dataset_ops.Dataset.from_tensors(1)

    def map_fn(x):
      # The assertion always fails (x == 1), poisoning the first pass.
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return x

    dataset = dataset.map(map_fn)
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=10).repeat()
    get_next = self.getNext(dataset)

    # First time around, we get an error for the failed assertion.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())

    # Second time around, we get an EOF because the cached dataset is empty.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.Dataset.interleave()`."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              input_values=[[4, 5, 6]],
              cycle_length=1,
              block_length=1,
              expected_elements=[[
                  4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 5,
                  5, 5, 5, 5, 6, 6, 6, 6, 6, 6
              ]]) + combinations.combine(
                  input_values=[[4, 5, 6]],
                  cycle_length=2,
                  block_length=1,
                  expected_elements=[[
                      4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6,
                      5, 6, 5, 6, 5, 6, 5, 6, 5, 6, 6
                  ]]) + combinations.combine(
                      input_values=[[4, 5, 6]],
                      cycle_length=2,
                      block_length=3,
                      expected_elements=[[
                          4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6,
                          6, 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6
                      ]]) + combinations.combine(
                          input_values=[[4, 5, 6]],
                          cycle_length=7,
                          block_length=2,
                          expected_elements=[[
                              4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5,
                              6, 6, 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6
                          ]]) + combinations.combine(
                              input_values=[[4, 0, 6]],
                              cycle_length=2,
                              block_length=1,
                              expected_elements=[[
                                  4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6,
                                  6, 6, 6, 6, 6
                              ]])))
  def testPythonImplementation(self, input_values, cycle_length, block_length,
                               expected_elements):
    """Checks the pure-Python reference `_interleave` against expectations."""
    input_lists = _repeat(input_values, 2)
    for expected, produced in zip(
        expected_elements,
        _interleave(input_lists, cycle_length, block_length)):
      self.assertEqual(expected, produced)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(input_values=[np.int64([4, 5, 6])],
                               cycle_length=1,
                               block_length=3,
                               num_parallel_calls=[None, 1]) +
          combinations.combine(input_values=[np.int64([4, 5, 6])],
                               cycle_length=2,
                               block_length=[1, 3],
                               num_parallel_calls=[None, 1, 2]) +
          combinations.combine(input_values=[np.int64([4, 5, 6])],
                               cycle_length=7,
                               block_length=2,
                               num_parallel_calls=[None, 1, 3, 5, 7]) +
          combinations.combine(input_values=[np.int64([4, 5, 6, 7])],
                               cycle_length=dataset_ops.AUTOTUNE,
                               block_length=3,
                               num_parallel_calls=[None, 1]) +
          combinations.combine(
              input_values=[np.int64([]), np.int64([0, 0, 0])],
              cycle_length=2,
              block_length=3,
              num_parallel_calls=[None]) +
          combinations.combine(input_values=[np.int64([4, 0, 6])],
                               cycle_length=2,
                               block_length=3,
                               num_parallel_calls=[None, 1, 2])))
  def testInterleaveDataset(self, input_values, cycle_length, block_length,
                            num_parallel_calls):
    """Deterministic interleave must match the Python reference exactly."""
    count = 2
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
        count).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length, block_length, num_parallel_calls)
    expected_output = list(
        _interleave(_repeat(input_values, count), cycle_length, block_length))
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              input_values=[np.float32([1., np.nan, 2., np.nan, 3.])],
              cycle_length=1,
              block_length=3,
              num_parallel_calls=[None, 1]) +
          combinations.combine(
              input_values=[np.float32([1., np.nan, 2., np.nan, 3.])],
              cycle_length=2,
              block_length=[1, 3],
              num_parallel_calls=[None, 1, 2]) +
          combinations.combine(
              input_values=[np.float32([1., np.nan, 2., np.nan, 3.])],
              cycle_length=7,
              block_length=2,
              num_parallel_calls=[None, 1, 3, 5, 7])))
  def testInterleaveDatasetError(self, input_values, cycle_length,
                                 block_length, num_parallel_calls):
    """Errors raised inside interleaved datasets surface in input order."""
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
        lambda x: array_ops.check_numerics(x, "message")).interleave(
            dataset_ops.Dataset.from_tensors, cycle_length, block_length,
            num_parallel_calls)
    get_next = self.getNext(dataset)

    for value in input_values:
      if np.isnan(value):
        # NaN trips check_numerics; the error must not terminate iteration.
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(get_next())
      else:
        self.assertEqual(value, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testInterleaveSparse(self):
    """Interleave over sparse map outputs densified via sparse_to_dense."""

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=(i * [1, -1]),
          dense_shape=[2, 2])

    def _interleave_fn(x):
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    dataset = dataset_ops.Dataset.range(10).map(_map_fn).interleave(
        _interleave_fn, cycle_length=1)
    get_next = self.getNext(dataset)

    for i in range(10):
      for j in range(2):
        # Each 2x2 dense matrix yields its two rows: [i, 0] then [0, -i].
        expected = [i, 0] if j % 2 == 0 else [0, -i]
        self.assertAllEqual(expected, self.evaluate(get_next()))
    # Checked twice: the exhausted iterator must keep raising.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(input_values=[np.int64([4, 5, 6])],
                               cycle_length=1,
                               block_length=3,
                               num_parallel_calls=1) +
          combinations.combine(input_values=[np.int64([4, 5, 6])],
                               cycle_length=2,
                               block_length=[1, 3],
                               num_parallel_calls=[1, 2]) +
          combinations.combine(input_values=[np.int64([4, 5, 6])],
                               cycle_length=7,
                               block_length=2,
                               num_parallel_calls=[1, 3, 5, 7]) +
          combinations.combine(input_values=[np.int64([4, 5, 6, 7])],
                               cycle_length=dataset_ops.AUTOTUNE,
                               block_length=3,
                               num_parallel_calls=1) +
          combinations.combine(input_values=[np.int64([4, 0, 6])],
                               cycle_length=2,
                               block_length=3,
                               num_parallel_calls=[1, 2])))
  def testSloppyInterleaveDataset(self, input_values, cycle_length,
                                  block_length, num_parallel_calls):
    """Non-deterministic interleave yields the same multiset of elements."""
    count = 2
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
        count).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length, block_length, num_parallel_calls)
    options = dataset_ops.Options()
    options.experimental_deterministic = False
    dataset = dataset.with_options(options)
    expected_output = list(
        _interleave(_repeat(input_values, count), cycle_length, block_length))
    get_next = self.getNext(dataset)
    actual_output = []
    for _ in range(len(expected_output)):
      actual_output.append(self.evaluate(get_next()))
    # Sloppy mode may reorder elements, so compare as sorted multisets.
    # BUG FIX: the original used `expected_output.sort()` and
    # `actual_output.sort()` as arguments; `list.sort()` returns None, so the
    # assertion compared None with None and could never fail.
    self.assertAllEqual(sorted(expected_output), sorted(actual_output))

  @combinations.generate(test_base.default_test_combinations())
  def testInterleaveMap(self):
    """Nested interleave+map composes: each element is doubled twice."""
    dataset = dataset_ops.Dataset.range(100)

    def interleave_fn(x):
      dataset = dataset_ops.Dataset.from_tensors(x)
      return dataset.map(lambda x: x + x)

    dataset = dataset.interleave(interleave_fn, cycle_length=5)
    dataset = dataset.interleave(interleave_fn, cycle_length=5)

    self.assertDatasetProduces(dataset, [4 * x for x in range(100)])

  @combinations.generate(test_base.default_test_combinations())
  def testParallelInterleaveCached(self):
    """Parallel interleave reads correctly from an on-disk cache."""
    dataset = dataset_ops.Dataset.range(5)
    dataset = dataset.cache(os.path.join(self.get_temp_dir(), "cache_dir"))

    def interleave_fn(x):
      return dataset_ops.Dataset.from_tensors(x)

    dataset = dataset.interleave(
        interleave_fn, cycle_length=2, num_parallel_calls=2)
    self.assertDatasetProduces(dataset, list(range(5)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(local_determinism=[None, True, False],
                               global_determinism=[True, False])))
  def testDeterminismConfiguration(self, local_determinism,
                                   global_determinism):
    """Op-level `deterministic` overrides the global determinism option."""
    expect_determinism = local_determinism or (local_determinism is None and
                                               global_determinism)
    elements = list(range(1000))

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        # Delay only element 0 so sloppy mode would reorder it.
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
      dataset = dataset.interleave(interleave_fn,
                                   cycle_length=10,
                                   num_parallel_calls=10,
                                   deterministic=local_determinism)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = global_determinism
      dataset = dataset.with_options(opts)
      return dataset

    self.checkDeterminism(dataset_fn, expect_determinism, elements)
class LegacySnapshotDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):
  """Tests for the legacy `snapshot.legacy_snapshot` transformation."""

  def setUp(self):
    # Each test gets a fresh, empty <tmp>/snapshot directory.
    super(LegacySnapshotDatasetTest, self).setUp()
    self.removeTFRecords()
    tmpdir = self.get_temp_dir()
    tmpdir = os.path.join(tmpdir, "snapshot")
    os.mkdir(tmpdir)
    self.snapshot_dir = tmpdir

  def tearDown(self):
    super(LegacySnapshotDatasetTest, self).tearDown()
    shutil.rmtree(self.snapshot_dir)

  def removeTFRecords(self):
    # Delete the source TFRecord files so reads can only come from snapshot.
    for filename in self.test_filenames:
      os.remove(filename)
    self.test_filenames = []

  def setUpTFRecord(self, num_files=10, num_records=10):
    # Creates the TFRecord fixture files via the base class helper.
    self._num_files = num_files
    self._num_records = num_records
    self.test_filenames = self._createFiles()

  def makeSnapshotDirectory(self):
    return self.snapshot_dir

  def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
                                      num_runs_per_fp, num_snapshot_files):
    """Asserts the on-disk layout: fingerprints / runs / shard files."""
    dirlist_raw = os.listdir(directory)
    dirlist = []
    # Ignore the graphdef pbtxts we write for debugging purposes.
    for i in range(len(dirlist_raw)):
      if not dirlist_raw[i].endswith("-graph.pbtxt"):
        dirlist.append(dirlist_raw[i])
    self.assertLen(dirlist, num_fingerprints)

    for i in range(num_fingerprints):
      fingerprint_dir = os.path.join(directory, dirlist[i])
      fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
      # Each fingerprint dir holds its run dirs plus one metadata file,
      # which sorts last.
      self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1)
      self.assertEqual(fingerprint_dir_list[num_runs_per_fp],
                       "snapshot.metadata")

      for j in range(num_runs_per_fp):
        run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
        run_dirlist = sorted(os.listdir(run_dir))
        self.assertLen(run_dirlist, num_snapshot_files)

        # Shard files are named with zero-padded sequential counters.
        file_counter = 0
        for filename in run_dirlist:
          self.assertEqual(filename, "%08d.snapshot" % file_counter)
          file_counter += 1

  @combinations.generate(test_base.default_test_combinations())
  def testWriteDifferentPipelinesInOneDirectory(self):
    """Two different pipelines produce two distinct fingerprints."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(1000)))

    dataset = dataset_ops.Dataset.range(1001)
    dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(1001)))

    self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotMultipleSimultaneous(self):
    """Two identical pipelines racing to write share one snapshot."""
    tmpdir = self.snapshot_dir

    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
    next1 = self.getNext(dataset1)

    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
    next2 = self.getNext(dataset2)

    for i in range(0, 1000):
      self.assertEqual(i, self.evaluate(next1()))
      self.assertEqual(i, self.evaluate(next2()))

    # we check that only one copy of the metadata has been written, and the
    # one that lost the race would be in passthrough mode.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testGetNextCreatesDir(self):
    """Snapshot dirs are created lazily, only when iteration starts."""
    tmpdir = self.snapshot_dir

    # We create two iterators but call getNext on only one.
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
    next1 = self.getNext(dataset1)

    dataset2 = dataset_ops.Dataset.range(1001)
    dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
    _ = self.getNext(dataset2)

    for _ in range(1000):
      self.evaluate(next1())

    # We check that only one directory is created.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
              snapshot.COMPRESSION_SNAPPY
          ])))
  def testWriteSnapshotSimpleSuccessful(self, compression):
    """A basic write pass succeeds for every compression scheme."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset, list(range(1000)))

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
              snapshot.COMPRESSION_SNAPPY
          ])))
  def testWriteSnapshotRepeatAfterwards(self, compression):
    """repeat() after snapshot re-reads the single written run."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
              snapshot.COMPRESSION_SNAPPY
          ])))
  def testWriteSnapshotMixTypes(self, compression):
    """Snapshot round-trips heterogeneous (int, string) element tuples."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)

    def map_fn(x):
      return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x)

    dataset = dataset.map(map_fn)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    dataset = dataset.repeat(10)

    expected = []
    for i in range(10):
      expected.append((i, str(i), str(2 * i), 2 * i))
    self.assertDatasetProduces(dataset, expected * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testSpecifySnapshotNameWriteAndRead(self):
    """A custom snapshot_name maps to a `custom-<name>` directory."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, snapshot_name="my_custom_snapshot"))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
    self.assertTrue(
        os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot")))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmpdir, "custom-my_custom_snapshot", "custom")))

  @combinations.generate(test_base.default_test_combinations())
  def testForcePassthroughMode(self):
    """mode="passthrough" writes nothing to the snapshot directory."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, mode="passthrough"))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testForceWriteMode(self):
    """mode="write" re-writes a fresh run on every repeat epoch."""
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="write"))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    # We will end up writing 10 different runs.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testForceReadMode(self):
    """mode="read" + snapshot_name reads a specific (renamed) run."""
    tmpdir = self.snapshot_dir

    # We write a copy of the snapshot first.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, mode="write",
                                 snapshot_name="my_custom_snapshot"))
    self.assertDatasetProduces(dataset, list(range(10)))

    # We move the run to a new name.
    shutil.move(os.path.join(tmpdir, "custom-my_custom_snapshot"),
                os.path.join(tmpdir, "custom-my_custom_snapshot_2"))

    # Even though the snapshot.metadata is pointing to the old run that no
    # longer exists after we moved, we force it to read from the run we
    # specify.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, mode="read",
                                 snapshot_name="my_custom_snapshot_2"))
    self.assertDatasetProduces(dataset, list(range(10)))

    # We should still have one snapshot and one run.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testForceReadNonexistentSnapshot(self):
    """Forcing read mode with no snapshot on disk raises NotFoundError."""
    tmpdir = self.snapshot_dir
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaises(errors.NotFoundError):
      dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="read"))
      get_next = self.getNext(dataset)
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testForceReadNonexistentNamedSnapshot(self):
    """Read mode with an unknown snapshot_name raises NotFoundError."""
    tmpdir = self.snapshot_dir
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaises(errors.NotFoundError):
      dataset = dataset.apply(
          snapshot.legacy_snapshot(
              tmpdir, mode="read", snapshot_name="my_nonexistent_snapshot"))
      get_next = self.getNext(dataset)
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
              snapshot.COMPRESSION_SNAPPY
          ])))
  def testReadSnapshotBackAfterWrite(self, compression):
    """Data written to a snapshot is readable after the sources are gone."""
    self.setUpTFRecord()
    filenames = self.test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset2, expected)
  @combinations.generate(test_base.default_test_combinations())
  def testReadShuffledSnapshotAfterWrite(self):
    """shuffle_on_read permutes shard order but preserves the multiset."""
    self.setUpTFRecord(num_files=10, num_records=50)
    filenames = self.test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 50)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    # Small shard size forces many shard files, giving shuffle_on_read
    # something to permute.
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100,
                                 shuffle_on_read=True))
    next2 = self.getNext(dataset2)

    res1 = self.evaluate(next2())
    res2 = self.evaluate(next2())
    res3 = self.evaluate(next2())
    res4 = self.evaluate(next2())
    res5 = self.evaluate(next2())

    # make sure that we don't read the file back in the same order.
    self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])

    # make sure all the elements are still there
    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100,
                                 shuffle_on_read=True))
    self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testReadShuffledSnapshotWithSeedAfterWrite(self):
    """The same shuffle_seed reproduces the same shuffled read order."""
    self.setUpTFRecord(num_files=10, num_records=50)
    filenames = self.test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 50)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10,
                                 shuffle_on_read=True, shuffle_seed=123456))
    next2 = self.getNext(dataset2)

    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10,
                                 shuffle_on_read=True, shuffle_seed=123456))
    next3 = self.getNext(dataset3)

    # make sure that the items are read back in the same order for both datasets
    for _ in range(500):
      res2 = self.evaluate(next2())
      res3 = self.evaluate(next3())
      self.assertEqual(res2, res3)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
              snapshot.COMPRESSION_SNAPPY
          ])))
  def testReadSnapshotParallelAfterWrite(self, compression):
    """Multi-threaded reads return all elements (order unspecified)."""
    self.setUpTFRecord(10, 4000)
    filenames = self.test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 4000)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir,
                                 shard_size_bytes=1024 * 1024,
                                 num_reader_threads=2,
                                 reader_buffer_size=10,
                                 compression=compression))
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

    # remove the original files and try to read the data back only from
    # snapshot.
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(tmpdir,
                                 shard_size_bytes=1024 * 1024,
                                 num_reader_threads=2,
                                 reader_buffer_size=10,
                                 compression=compression))
    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

  # Not testing Snappy here because Snappy reads currently require a lot of
  # memory.
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(compression=[
                  snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
              ]),
              combinations.combine(threads=2, size=[1, 2]) +
              combinations.combine(threads=8, size=[1, 4, 8]))))
  def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
                                                  size):
    """Snapshots written with multiple writer threads read back intact."""
    self.setUpTFRecord()
    filenames = self.test_filenames
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir,
                                 compression=compression,
                                 num_writer_threads=threads,
                                 writer_buffer_size=size))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from
    # snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
@combinations.generate(test_base.default_test_combinations()) def testSameFingerprintWithDifferentInitializationOrder(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(0, 100) dataset2 = dataset_ops.Dataset.range(100, 200) dataset3 = dataset_ops.Dataset.range(200, 300) dataset = dataset1.concatenate(dataset2).concatenate(dataset3) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(300))) dataset4 = dataset_ops.Dataset.range(200, 300) dataset5 = dataset_ops.Dataset.range(100, 200) dataset6 = dataset_ops.Dataset.range(0, 100) dataset = dataset6.concatenate(dataset5).concatenate(dataset4) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(300))) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testExpiredSnapshotRewrite(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply( snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) next1 = self.getNext(dataset1) # Don't finish reading dataset1, so it is never finalized for _ in range(500): self.evaluate(next1()) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) time.sleep(2) # Creating dataset2 after we run through dataset1 due to eager mode, where # the snapshot state is determined immediately upon dataset creation. We # only want to determine the snapshot state for dataset2 after the first # snapshot has expired. 
dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) next2 = self.getNext(dataset2) for _ in range(500): self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1) @combinations.generate(test_base.default_test_combinations()) def testSnapshotArgsCreateNewSnapshot(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10000)) next1 = self.getNext(dataset1) for _ in range(1000): self.evaluate(next1()) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) # Create second snapshot with a different shard_size_bytes dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset1.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=20000)) next2 = self.getNext(dataset2) for _ in range(1000): self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testSpecifyShardSize(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset.map( lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.repeat(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression)) next_fn = self.getNext(dataset) for _ in range(10): self.evaluate(next_fn()) num_files = 1 if compression == snapshot.COMPRESSION_NONE: num_files = 3 self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files) @combinations.generate(test_base.default_test_combinations()) def testAdditionalOperationsAfterReadBack(self): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension 
for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset2, expected) expected_after = [ b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] dataset3 = core_readers._TFRecordDataset(filenames) dataset3 = dataset3.apply(snapshot.legacy_snapshot(tmpdir)) dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000)) self.assertDatasetProduces(dataset3, expected_after)
class BucketBySequenceLengthTest(test_base.DatasetTestBase,
                                 parameterized.TestCase):
  """Tests for `grouping.bucket_by_sequence_length`."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(param_no_padding=[True, False])))
  def testBucketDropReminder(self, param_no_padding):
    # NOTE(review): "Reminder" in the name is presumably a typo for
    # "Remainder" (matches the `drop_remainder=True` under test); renaming
    # would change the test's public identifier, so it is left as-is.

    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25, 35]

    # Number of elements generated per bucket; chosen so every bucket has a
    # partial (dropped) remainder batch.
    n_bucket_elements = [28, 7, 6, 5]
    n_expected_batches = 5

    # Expected sequence lengths of the individual batches.
    expected_lengths = []

    # Expected sum of all batches with an equal sequence length.
    # <seq-length>: <expected-total-sum>
    expected_sums = {}

    # Expected batch sizes of batches depending on the sequence length.
    # <seq-length>: [batch1_size, ..., batchN_size]
    expected_batch_sizes = {}

    for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
                                                   n_bucket_elements):
      # Calculate the expected sum across all batches of a specific sequence
      # length.
      expected_sums[length] = \
          (bucket_elements - bucket_elements % batch_size) * length
      # Calculate the expected occurrence of individual batch sizes.
      expected_batch_sizes[length] = \
          [batch_size] * (bucket_elements // batch_size)
      # Calculate the expected occurrence of individual sequence lengths.
      expected_lengths.extend([length] * (bucket_elements // batch_size))

    def build_dataset(sparse):
      """Builds a shuffled dataset of all-ones records of the bucket lengths."""

      def _generator():
        # Produce 1 batch for each bucket
        elements = []
        for bucket_elements, length in zip(n_bucket_elements, lengths):
          # Using only full sequences (opposed to the strategy employed in
          # `testBucket`) makes checking the sum a lot easier.
          record_len = length
          for _ in range(bucket_elements):
            elements.append([1] * record_len)
        random.shuffle(elements)
        for el in elements:
          yield (_format_record(el, sparse),)

      dataset = dataset_ops.Dataset.from_generator(
          _generator, (_get_record_type(sparse),), (_get_record_shape(sparse),))
      if sparse:
        dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
      return dataset

    def _test_bucket_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(
          grouping.bucket_by_sequence_length(
              _element_length_fn,
              boundaries,
              batch_sizes,
              no_padding=no_padding,
              drop_remainder=True))
      get_next = self.getNext(dataset)
      batches = []
      for _ in range(n_expected_batches):
        batch, = self.evaluate(get_next())
        batches.append(batch)
      # Remainder batches were dropped, so the dataset must be exhausted now.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

      generated_lengths = []

      # <seq-length>: <total-sum>
      generated_sums = {}

      # <seq-length>: [<batch_size>, ...]
      generated_batch_sizes = {}

      for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
                                                     n_bucket_elements):
        # Initialize the sum across all batches.
        generated_sums[length] = 0
        # Initialize the individual batch sizes.
        generated_batch_sizes[length] = []

      for batch in batches:
        # Sparse batches expose their size via dense_shape.
        shape = batch.dense_shape if no_padding else batch.shape
        length = shape[1]
        generated_lengths.append(length)
        batch_size = shape[0]
        generated_batch_sizes[length].append(batch_size)
        batch_sum = batch.values.sum() if no_padding else batch.sum()
        generated_sums[length] += batch_sum

      for l in lengths:
        # Make sure the sum of the batch contents is correct for the
        # individual sequence lengths.
        self.assertEqual(
            generated_sums[l], expected_sums[l], "Tensor sums did not match! "
            "expected: {}, generated: {}".format(expected_sums, generated_sums))

        # Make sure the individual batch sizes are generated as expected.
        self.assertEqual(
            sorted(generated_batch_sizes[l]), sorted(expected_batch_sizes[l]),
            "Batch-sizes did not match! "
            "expected: {}, generated: {}".format(
                sorted(expected_batch_sizes[l]),
                sorted(generated_batch_sizes[l])))

      # Make sure the generated sequence lengths appear as often as expected.
      self.assertEqual(
          sorted(generated_lengths), sorted(expected_lengths),
          "The generated sequence lengths did not match! "
          "expected: {}, generated: {}".format(
              sorted(expected_lengths), sorted(generated_lengths)))

    _test_bucket_by_padding(param_no_padding)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(param_no_padding=[True, False])))
  def testBucket(self, param_no_padding):

    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25, 35]

    def build_dataset(sparse):
      """Builds a dataset where exactly one element per bucket is one shorter."""

      def _generator():
        # Produce 1 batch for each bucket
        elements = []
        for batch_size, length in zip(batch_sizes, lengths):
          # First record of each bucket is one shorter; this makes each batch
          # sum equal batch_size * length - 1 after padding.
          record_len = length - 1
          for _ in range(batch_size):
            elements.append([1] * record_len)
            record_len = length
        random.shuffle(elements)
        for el in elements:
          yield (_format_record(el, sparse),)

      dataset = dataset_ops.Dataset.from_generator(
          _generator, (_get_record_type(sparse),), (_get_record_shape(sparse),))
      if sparse:
        dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
      return dataset

    def _test_bucket_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(
          grouping.bucket_by_sequence_length(
              _element_length_fn,
              boundaries,
              batch_sizes,
              no_padding=no_padding))
      get_next = self.getNext(dataset)
      batches = []
      for _ in range(4):
        batch, = self.evaluate(get_next())
        batches.append(batch)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
      batch_sizes_val = []
      lengths_val = []
      for batch in batches:
        shape = batch.dense_shape if no_padding else batch.shape
        batch_size = shape[0]
        length = shape[1]
        batch_sizes_val.append(batch_size)
        lengths_val.append(length)
        # NOTE(review): sum check only runs in graph mode here — presumably an
        # eager-mode limitation at the time this was written; confirm.
        if not context.executing_eagerly():
          sum_check = batch.values.sum() if no_padding else batch.sum()
          self.assertEqual(sum_check, batch_size * length - 1)
      self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
      self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
      self.assertEqual(sorted(lengths), sorted(lengths_val))

    _test_bucket_by_padding(param_no_padding)

  def testPadToBoundary(self):
    # NOTE(review): no @combinations.generate decorator here, unlike sibling
    # tests — this runs under the plain test runner only.

    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25]

    def element_gen():
      # Produce 1 batch for each bucket
      elements = []
      for batch_size, length in zip(batch_sizes[:-1], lengths):
        for _ in range(batch_size):
          elements.append([1] * length)
      random.shuffle(elements)
      for el in elements:
        yield (el,)
      # Elements longer than the last boundary cannot be padded to a boundary
      # and should raise when reached.
      for _ in range(batch_sizes[-1]):
        el = [1] * (boundaries[-1] + 5)
        yield (el,)

    element_len = lambda el: array_ops.shape(el)[0]
    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],)).apply(
            grouping.bucket_by_sequence_length(
                element_len, boundaries, batch_sizes,
                pad_to_bucket_boundary=True))
    get_next = self.getNext(dataset)

    batches = []
    for _ in range(3):
      batch, = self.evaluate(get_next())
      batches.append(batch)
    with self.assertRaisesOpError("bucket_boundaries"):
      self.evaluate(get_next())

    batch_sizes_val = []
    lengths_val = []
    for batch in batches:
      batch_size = batch.shape[0]
      length = batch.shape[1]
      batch_sizes_val.append(batch_size)
      lengths_val.append(length)
    # The overflow bucket's batch never materialized; drop its size.
    batch_sizes = batch_sizes[:-1]
    self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
    self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
    # Padded length is boundary - 1 (boundaries are exclusive upper bounds).
    self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
                     sorted(lengths_val))

  def testPadToBoundaryNoExtraneousPadding(self):

    boundaries = [3, 7, 11]
    batch_sizes = [2, 2, 2, 2]
    lengths = range(1, 11)

    def element_gen():
      for length in lengths:
        yield ([1] * length,)

    element_len = lambda element: array_ops.shape(element)[0]
    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],)).apply(
            grouping.bucket_by_sequence_length(
                element_len, boundaries, batch_sizes,
                pad_to_bucket_boundary=True))
    get_next = self.getNext(dataset)

    batches = []
    for _ in range(5):
      batch, = self.evaluate(get_next())
      batches.append(batch)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Each batch is padded only to its bucket boundary minus one, not beyond.
    self.assertAllEqual(batches[0], [[1, 0], [1, 1]])
    self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]])
    self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]])
    self.assertAllEqual(
        batches[3],
        [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    self.assertAllEqual(
        batches[4],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(param_no_padding=[True, False])))
  def testTupleElements(self, param_no_padding):
    """Bucketing tuple elements preserves per-component output shapes."""

    def build_dataset(sparse):

      def _generator():
        text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
        label = [1, 2, 1, 2]
        for x, y in zip(text, label):
          yield (_format_record(x, sparse), y)

      dataset = dataset_ops.Dataset.from_generator(
          generator=_generator,
          output_types=(_get_record_type(sparse), dtypes.int32),
          output_shapes=(_get_record_shape(sparse),
                         tensor_shape.TensorShape([])))
      if sparse:
        dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
      return dataset

    def _test_tuple_elements_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(
          grouping.bucket_by_sequence_length(
              element_length_func=_element_length_fn,
              bucket_batch_sizes=[2, 2, 2],
              bucket_boundaries=[0, 8],
              no_padding=no_padding))
      shapes = dataset_ops.get_legacy_output_shapes(dataset)
      self.assertEqual([None, None], shapes[0].as_list())
      self.assertEqual([None], shapes[1].as_list())

    _test_tuple_elements_by_padding(param_no_padding)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(param_drop_remainder=[True, False])))
  def testBucketSparse(self, param_drop_remainder):  # pylint: disable=g-doc-args
    """Tests bucketing of sparse tensors (case where `no_padding` == True).

    Test runs on following dataset:
      [
        [0],
        [0, 1],
        [0, 1, 2]
        ...
        [0, ..., max_len - 1]
      ]
    Sequences are bucketed by length and batched with
      `batch_size` < `bucket_size`.
    """

    min_len = 0
    max_len = 100
    batch_size = 7
    bucket_size = 10

    def _build_dataset():
      input_data = [range(i + 1) for i in range(min_len, max_len)]

      def generator_fn():
        for record in input_data:
          yield _format_record(record, sparse=True)

      dataset = dataset_ops.Dataset.from_generator(
          generator=generator_fn, output_types=_get_record_type(sparse=True))
      dataset = dataset.map(_to_sparse_tensor)
      return dataset

    def _compute_expected_batches(drop_remainder):
      """Computes expected batch outputs and stores in a set."""
      all_expected_sparse_tensors = set()
      for bucket_start_len in range(min_len, max_len, bucket_size):
        # With drop_remainder only the first full batch of each bucket
        # survives; otherwise the bucket yields ceil(bucket/batch) batches.
        if drop_remainder:
          batch_offsets = [0]
        else:
          batch_offsets = range(0, bucket_size, batch_size)

        for batch_offset in batch_offsets:
          batch_start_len = bucket_start_len + batch_offset
          batch_end_len = min(batch_start_len + batch_size,
                              bucket_start_len + bucket_size)
          expected_indices = []
          expected_values = []
          for length in range(batch_start_len, batch_end_len):
            for val in range(length + 1):
              expected_indices.append((length - batch_start_len, val))
              expected_values.append(val)
          expected_sprs_tensor = (tuple(expected_indices),
                                  tuple(expected_values))
          all_expected_sparse_tensors.add(expected_sprs_tensor)
      return all_expected_sparse_tensors

    def _compute_batches(dataset):
      """Computes actual batch outputs of dataset and stores in a set."""
      batch = self.getNext(dataset)
      all_sparse_tensors = set()
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          output = self.evaluate(batch())
          # Hashable (indices, values) tuples so batches can be set-compared
          # independent of production order.
          sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
                         tuple(output.values))
          all_sparse_tensors.add(sprs_tensor)
      return all_sparse_tensors

    dataset = _build_dataset()
    boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
    dataset = dataset.apply(
        grouping.bucket_by_sequence_length(
            _element_length_fn,
            boundaries, [batch_size] * (len(boundaries) + 1),
            no_padding=True,
            drop_remainder=param_drop_remainder))
    batches = _compute_batches(dataset)
    expected_batches = _compute_expected_batches(param_drop_remainder)
    self.assertEqual(batches, expected_batches)
class MultiDeviceIteratorCommonTest(test_base.DatasetTestBase,
                                    parameterized.TestCase):
  """Tests that are common to MultiDeviceIterator and OwnedMultiDeviceIterator."""

  def setUp(self):
    super().setUp()
    # Three logical devices: host plus two iterator target devices.
    self._devices = self.configureDevicesForMultiDeviceTest(3)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(), cls_combination))
  def testCancelGetNextWithDevice(self, cls):
    """Cancelling one device's get_next must not poison the iterator.

    The ping/pong queues let the test block the map function mid-element so
    the cancellation provably lands while get_next is in flight.
    """
    ping = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)
    pong = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)

    @def_function.function
    def map_fn(v):
      # Blocks until the test enqueues to `ping`; signals progress via `pong`.
      ball = ping.dequeue()
      with ops.control_dependencies([pong.enqueue(ball)]):
        return v + ping.dequeue()

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.map(map_fn)

    # We need to set prefetch_buffer_size=0 so that we can cancel the
    # MultiDeviceIteratorGetNextFromShardOp from eager. If
    # prefetch_buffer_size>0, that op runs in the background threads of the
    # prefetch and can only be cancelled by deleting the iterator.
    multi_device_iterator = cls(
        dataset, [self._devices[1], self._devices[2]], prefetch_buffer_size=0)

    @def_function.function
    def get_next_device1():
      return multi_device_iterator.get_next(self._devices[1])

    # Run get_next asynchronously so the test thread stays free to drive the
    # queues and issue the cancellation.
    async_executor = executor.new_executor(enable_async=True)
    with context.executor_scope(async_executor):
      cancel_mgr = cancellation.CancellationManager()
      cancel_mgr.get_cancelable_function(
          get_next_device1.get_concrete_function())()
    # Make sure we cancel in the middle of get_next.
    ping.enqueue(0)
    pong.dequeue()
    cancel_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      async_executor.wait()
    # Note that fetching from upstream iterator is not cancelled with the
    # cancellation of get_next.
    ping.enqueue(0)

    # Cancelling a get_next on one device shouldn't cancel the
    # multi_device_iterator and iterators on other devices.
    ping.enqueue(0)
    ping.enqueue(0)
    self.assertEqual(1,
                     multi_device_iterator.get_next(self._devices[2]).numpy())

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(), cls_combination))
  def testEmptyDataset(self, cls):
    """get_next on an empty dataset raises OutOfRangeError immediately."""
    dataset = dataset_ops.Dataset.range(0)
    multi_device_iterator = cls(
        dataset, devices=[self._devices[1], self._devices[2]])
    with self.assertRaises(errors.OutOfRangeError):
      multi_device_iterator.get_next()

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(), cls_combination))
  def testEmptyDeviceList(self, cls):
    """Constructing with an empty device list is rejected by the op."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Length for attr 'devices' of 0 must be at least minimum 1"):
      cls(dataset, devices=[])
class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
                                   parameterized.TestCase):
  """Tests specific to OwnedMultiDeviceIterator (eager / tf.function)."""

  def setUp(self):
    super(OwnedMultiDeviceIteratorTest, self).setUp()
    # Three logical devices: host plus two iterator target devices.
    self._devices = self.configureDevicesForMultiDeviceTest(3)

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(max_buffer_size=[0, 1, 10],
                               prefetch_buffer_size=[0, 1, 10])))
  def testBasic(self, max_buffer_size, prefetch_buffer_size):
    """Elements are distributed round-robin over the two devices."""
    dataset = dataset_ops.Dataset.range(1000)
    mdi = multi_device_iterator_ops.OwnedMultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]],
        max_buffer_size=max_buffer_size,
        prefetch_buffer_size=prefetch_buffer_size)

    for i, el in enumerate(mdi):
      self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])

  @combinations.generate(test_base.eager_only_combinations())
  def testBasicFunction(self):
    """The iterator can be created and consumed inside a tf.function."""
    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

    @def_function.function
    def fn():
      with ops.device(self._devices[0]):
        dataset = dataset_ops.Dataset.range(10)
      iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
          dataset, [self._devices[1], self._devices[2]])
      for _ in range(5):
        el0, el1 = next(iterator)
        queue.enqueue(el0)
        queue.enqueue(el1)

    fn()

    # Round-robin delivery means the queue drains back in order 0..9.
    for i in range(10):
      self.assertEqual(queue.dequeue().numpy(), i)

  @combinations.generate(test_base.eager_only_combinations())
  def testFunctionError(self):
    # In this test we verify that a function that raises an error ends up
    # properly deallocating the iterator resource.

    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
    queue.enqueue(0)

    def init_fn(n):
      return n

    def next_fn(_):
      # Empty range: next() raises OutOfRangeError on first call.
      ds = dataset_ops.Dataset.range(0)
      return next(iter(ds))

    def finalize_fn(n):
      # Enqueues a marker so the test can observe that finalization ran.
      queue.enqueue(0)
      return n

    @def_function.function
    def fn():
      dataset = dataset_ops._GeneratorDataset(
          1,
          init_fn,
          next_fn,
          finalize_fn,
          output_signature=tensor_spec.TensorSpec([], dtypes.int64))
      iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
          dataset, [self._devices[1], self._devices[2]])
      next(iterator)

    with self.assertRaises(errors.OutOfRangeError):
      fn()
    # Initial marker + finalize marker => resource was deallocated.
    self.assertEqual(queue.size().numpy(), 2)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleInitializations(self):
    """Re-creating the iterator over the same dataset works repeatedly."""
    dataset = dataset_ops.Dataset.range(1000)

    for _ in range(5):
      multi_device_iterator = (
          multi_device_iterator_ops.OwnedMultiDeviceIterator(
              dataset, [self._devices[1], self._devices[2]]))
      for i, el in enumerate(multi_device_iterator):
        self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])

  @combinations.generate(test_base.eager_only_combinations())
  def testLimitedRetracing(self):
    """Passing different iterators to a tf.function traces only once."""
    trace_count = [0]

    @def_function.function
    def f(iterator):
      trace_count[0] += 1
      counter = np.int64(0)
      for _ in range(5):
        elem = next(iterator)
        counter += elem[0]
        counter += elem[1]
      return counter

    dataset = dataset_ops.Dataset.range(10)
    dataset2 = dataset_ops.Dataset.range(20)

    for _ in range(10):
      multi_device_iterator = (
          multi_device_iterator_ops.OwnedMultiDeviceIterator(
              dataset, [self._devices[1], self._devices[2]]))
      self.assertEqual(self.evaluate(f(multi_device_iterator)), 45)
      multi_device_iterator2 = (
          multi_device_iterator_ops.OwnedMultiDeviceIterator(
              dataset2, [self._devices[1], self._devices[2]]))
      # 0+1+...+9 == 45 in both cases (only the first 5 pairs are consumed).
      self.assertEqual(self.evaluate(f(multi_device_iterator2)), 45)
      self.assertEqual(trace_count[0], 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testMissingDevices(self):
    """Omitting `devices` raises a ValueError."""
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaisesRegex(ValueError, "`devices` must be provided."):
      multi_device_iterator_ops.OwnedMultiDeviceIterator(dataset)

  @combinations.generate(test_base.eager_only_combinations())
  def testMissingInput(self):
    """Omitting `dataset` requires both `components` and `element_spec`."""
    with self.assertRaisesRegex(
        ValueError,
        "When `dataset` is not provided, both `components` and `element_spec` "
        "must be specified."):
      multi_device_iterator_ops.OwnedMultiDeviceIterator(
          dataset=None, devices=[self._devices[1], self._devices[2]])

  @combinations.generate(test_base.eager_only_combinations())
  def testExtraElementSpecInput(self):
    """Providing both `dataset` and `element_spec` is rejected."""
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaisesRegex(
        ValueError,
        "When `dataset` is provided, `element_spec` and `components` must "
        "not be specified."):
      multi_device_iterator_ops.OwnedMultiDeviceIterator(
          dataset,
          devices=[self._devices[1], self._devices[2]],
          element_spec=dataset.element_spec)

  @combinations.generate(test_base.graph_only_combinations())
  def testGraphMode(self):
    """Construction outside eager/tf.function raises RuntimeError."""
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaisesRegex(
        RuntimeError,
        "OwnedMultiDeviceIterator is only supported inside of tf.function or "
        "when eager execution is enabled."):
      multi_device_iterator_ops.OwnedMultiDeviceIterator(
          dataset, devices=[self._devices[1], self._devices[2]])
class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `get_single_element.get_single_element`."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) +
          combinations.combine(
              skip=[100],
              take=[1],
              error=[errors.InvalidArgumentError],
              error_msg=["Dataset was empty."]) +
          combinations.combine(
              skip=[0],
              take=[2],
              error=[errors.InvalidArgumentError],
              error_msg=["Dataset had more than one element."])))
  def testGetSingleElement(self, skip, take, error=None, error_msg=None):
    """Extracts the single (dense, sparse) element or raises for 0/2+ elements."""

    def make_sparse(x):
      # Scalar -> rank-1 SparseTensor with one entry at index [x].
      x_1d = array_ops.reshape(x, [1])
      x_2d = array_ops.reshape(x, [1, 1])
      return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)

    dataset = dataset_ops.Dataset.range(100).skip(skip).map(
        lambda x: (x * x, make_sparse(x))).take(take)
    if error is None:
      dense_val, sparse_val = self.evaluate(
          get_single_element.get_single_element(dataset))
      self.assertEqual(skip * skip, dense_val)
      self.assertAllEqual([[skip]], sparse_val.indices)
      self.assertAllEqual([skip], sparse_val.values)
      self.assertAllEqual([skip], sparse_val.dense_shape)
    else:
      # FIX: use assertRaisesRegex; assertRaisesRegexp is a deprecated alias
      # (removed in Python 3.12) and the rest of this file already uses the
      # non-deprecated spelling.
      with self.assertRaisesRegex(error, error_msg):
        self.evaluate(get_single_element.get_single_element(dataset))

  @combinations.generate(test_base.default_test_combinations())
  def testWindow(self):
    """Test that `get_single_element()` can consume a nested dataset."""

    def flat_map_func(ds):
      batched = ds.batch(2)
      element = get_single_element.get_single_element(batched)
      return dataset_ops.Dataset.from_tensors(element)

    dataset = dataset_ops.Dataset.range(10).window(2).flat_map(flat_map_func)
    self.assertDatasetProduces(dataset,
                               [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])

  @combinations.generate(test_base.default_test_combinations())
  def testSideEffect(self):
    """Side effects inside the consumed dataset run exactly once."""
    counter_var = variables.Variable(0)

    def increment_fn(x):
      counter_var.assign_add(1)
      return x

    def dataset_fn():
      return dataset_ops.Dataset.range(1).map(increment_fn)

    @function.defun
    def fn():
      _ = get_single_element.get_single_element(dataset_fn())
      return "hello"

    self.evaluate(counter_var.initializer)
    self.assertEqual(self.evaluate(fn()), b"hello")
    self.assertEqual(self.evaluate(counter_var), 1)

  @combinations.generate(test_base.default_test_combinations())
  def testAutomaticControlDependencies(self):
    """Two get_single_element calls execute in program order: (1+1)*2 == 4."""
    counter_var = variables.Variable(1)

    def increment_fn(x):
      counter_var.assign(counter_var + 1)
      return x

    def multiply_fn(x):
      counter_var.assign(counter_var * 2)
      return x

    def dataset1_fn():
      return dataset_ops.Dataset.range(1).map(increment_fn)

    def dataset2_fn():
      return dataset_ops.Dataset.range(1).map(multiply_fn)

    @function.defun
    def fn():
      _ = get_single_element.get_single_element(dataset1_fn())
      _ = get_single_element.get_single_element(dataset2_fn())
      return "hello"

    self.evaluate(counter_var.initializer)
    self.assertEqual(self.evaluate(fn()), b"hello")
    self.assertEqual(self.evaluate(counter_var), 4)
class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
                           parameterized.TestCase):
  """Tests for `distribute._AutoShardDataset`.

  The fixture (from `TFRecordDatasetTestBase._createFiles`) writes
  `self._num_files` TFRecord files with `self._num_records` records each;
  record payloads are b"Record <r> of file <f>". Tests then apply
  auto-sharding with various (num_workers, index) pairs and check which
  records each shard sees.
  """

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    self.test_filenames = self._createFiles()

  def getAllDatasetElements(self, dataset):
    # Drains `dataset` to exhaustion and returns all elements as a list.
    actual = []
    next_fn = self.getNext(dataset)
    while True:
      try:
        actual.append(self.evaluate(next_fn()))
      except errors.OutOfRangeError:
        break
    return actual

  def assertDatasetProducesWithShuffle(self, dataset, expected, batch,
                                       num_examples, shuffle):
    # With shuffling the element order is unspecified, so compare as
    # multisets after flattening `num_examples` batches; without shuffling,
    # compare batches in order.
    if shuffle:
      actual = []
      next_fn = self.getNext(dataset)
      for _ in range(num_examples):
        elem = self.evaluate(next_fn())
        if isinstance(elem, tuple):
          actual.extend(elem)
        else:
          actual.extend(elem.tolist())

      self.assertCountEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_fn())
    else:
      self.assertDatasetProduces(dataset, list(chunk(expected, batch)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testFlatMapReaderPipeline(self, shuffle):
    # 5 workers, worker index 3: file-level sharding keeps files 3 and 8.
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (3, 8)
        for r in range(0, 10)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(batch_size=[1, 3, 10])))
  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
    # This tests a scenario where a list_files may return multiple files
    # due to the glob containing wildcards.
    def batch(iterator, n):
      # Yields successive slices of `iterator` of length at most n.
      l = len(iterator)
      for i in range(0, l, n):
        yield iterator[i:min(i + n, l)]

    datasets = []
    for files in batch(self.test_filenames, batch_size):
      datasets.append(
          dataset_ops.Dataset.list_files(files, shuffle=False).map(
              core_readers.TFRecordDataset))
    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
    dataset = dataset.flat_map(lambda x: x)

    # Simulate additional ops in between flat_map and interleave. This should be
    # a no-op since if ShardDataset is placed right after flat_map, we will only
    # have two datasets left at this point.
    dataset = dataset.prefetch(1)
    dataset = dataset.prefetch(1)

    dataset = dataset.interleave(
        lambda x: x, cycle_length=1, num_parallel_calls=1)

    dataset = distribute._AutoShardDataset(dataset, 5, 0)
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testZipReaderPipeline(self):
    # Both zipped branches must be sharded identically so elements stay paired.
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f))  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testConcatenateReaderPipeline(self, shuffle):
    # Each concatenated half is sharded to files (3, 8); the shard's output is
    # the concatenation of both halves, hence `expected` is doubled.
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset1 = dataset1.batch(5)
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset2.batch(5)

    dataset = dataset1.concatenate(dataset2)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    expected += expected
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testPipelineWithMap(self, shuffle):
    # A map between the reader and the shard point must not break file-level
    # sharding. substr_v2(x, 2, 1000) strips the leading "Re" from each record.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(test_base.default_test_combinations())
  def testDirectFilenameTFRecordReaderPipeline(self):
    # Filenames passed straight to TFRecordDataset (no list_files) can still
    # be sharded at the file level.
    dataset = core_readers.TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testValidPipelineWithRangeDataset(self, shuffle):
    # Filenames are synthesized from a range dataset instead of list_files;
    # the filename pattern matches what _createFiles writes to get_temp_dir().
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5),
                                       (10, 1, 1, 10)])))
  def testStandardReaderPipeline(self, params):
    # params = (num_epochs, worker_index, batch_size, parallel_reads).
    num_epochs, index, batch_size, parallel_reads = params
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        parser_fn=None,
        num_parallel_reads=parallel_reads,
        drop_final_batch=True,
        shuffle=False)
    dataset = distribute._AutoShardDataset(dataset, 2, index)
    outputs = self.getNext(dataset)
    self._verify_records(
        outputs,
        batch_size=batch_size,
        # NOTE(review): iterates with self._num_records as the bound; here
        # num_records == num_files == 10, so this equals the file range —
        # confirm intent if the fixture sizes ever diverge.
        file_index=[i for i in range(index, self._num_records, 2)],
        num_epochs=num_epochs,
        interleave_cycle_length=parallel_reads,
        drop_final_batch=True,
        use_parser_fn=None)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testSampleResNetPipeline(self, shuffle):
    # Typical input pipeline shape: list_files -> parallel_interleave -> batch.
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              distribute_options.AutoShardPolicy.DATA,
              distribute_options.AutoShardPolicy.AUTO
          ])))
  def testShardByDataBeforePrefetch(self, sharding_policy):
    # assert_next checks the rewriter places ShardDataset before Prefetch.
    dataset = dataset_ops.Dataset.range(4)
    dataset = dataset.apply(testing.assert_next(["Shard", "Prefetch"]))
    dataset = dataset.prefetch(1)
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [0, 2])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(
                  sharding_policy=[distribute_options.AutoShardPolicy.DATA,
                                   distribute_options.AutoShardPolicy.FILE]),
              combinations.combine(shuffle=[True, False]))))
  def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy):
    # Replicates one serialized pipeline into two remote datasets and shards
    # them as workers 0 and 1 of 2; their outputs must not overlap.
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)

    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=distribute_options.ExternalStatePolicy.WARN)

    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy

    ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)
    ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)

    ds1 = ds1.with_options(options)
    ds2 = ds2.with_options(options)

    ds1 = distribute._AutoShardDataset(ds1, 2, 0)
    ds2 = distribute._AutoShardDataset(ds2, 2, 1)

    elems1 = set(self.getAllDatasetElements(ds1))
    elems2 = set(self.getAllDatasetElements(ds2))
    self.assertEmpty(elems1.intersection(elems2))

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFilesWithDataSharding(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.DATA)

    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return "Record (0,5) of file (0 --> 9)" since we are sharding by
    # individual elements, we should be able to get some data from all files.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testAutoshardPolicyOff(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.OFF)

    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return every record in every file since autosharding is turned off.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithoutReaderDatasetOp(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.FILE)

    dataset = dataset_ops.Dataset.range(1024)
    dataset = dataset.with_options(options)

    # We are specifying that we want a file sharding policy, and this pipeline
    # doesn't start with file reading, so we should error out.
    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFiles(self):
    # Worker 499 of 500 with only 10 files gets an empty shard.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 500, 499)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNames(self):
    # Using `_TFRecordDataset` creates a raw op rather than wrapping it around
    # a flat_map automatically.
    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNamesAndShapes(self):
    # Using `_TFRecordDataset` creates a raw op rather than wrapping it around
    # a flat_map automatically.
    dataset = core_readers._TFRecordDataset(self.test_filenames)

    # BatchDataset contains `output_types` and `output_shapes`
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 5)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRange(self):
    # 10 workers for a 5-element dataset is rejected.
    dataset = dataset_ops.Dataset.range(5)
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRangeEmptyDataset(self):
    dataset = dataset_ops.Dataset.range(0)
    with self.assertRaises(errors.OutOfRangeError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testNoReaderPipelines(self):
    # Without a file reader, auto-sharding falls back to data (element) level.
    dataset = dataset_ops.Dataset.range(1024)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0])

  @combinations.generate(test_base.default_test_combinations())
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
    # An op the rewriter doesn't recognize (unique) forces sharding to happen
    # at the end of the pipeline, i.e. element-level.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testInvalidWorkerIndex(self):
    # Worker index must be < num_workers; index 2 of 2 is invalid.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)

    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 2, 2)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testAssertCardinality(self):
    # assert_cardinality in the pipeline must not block file-level sharding.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset.apply(cardinality.assert_cardinality(42))
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testMaxIntraOpParallelism(self):
    # _MaxIntraOpParallelismDataset wrapper must not block file-level sharding.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testPrivateThreadpool(self):
    # _PrivateThreadPoolDataset wrapper must not block file-level sharding.
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testMakeBatchedFeaturesDataset(self):
    # Builds its own Example files (each record carries its file index) so
    # that after file-level sharding every parsed "file" feature equals 0.
    files = 2
    records_per_file = 5

    def make_record(file_index):
      example = example_pb2.Example(
          features=feature_pb2.Features(
              feature={
                  "file":
                      feature_pb2.Feature(
                          int64_list=feature_pb2.Int64List(value=[file_index])),
              }))
      return example.SerializeToString()

    filenames = []
    for file_index in range(files):
      filename = os.path.join(self.get_temp_dir(),
                              "tf_record.%d.txt" % file_index)
      filenames.append(filename)
      writer = python_io.TFRecordWriter(filename)
      for _ in range(records_per_file):
        writer.write(make_record(file_index))
      writer.close()

    dataset = readers.make_batched_features_dataset(
        file_pattern=filenames,
        batch_size=records_per_file,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        },
        reader=core_readers.TFRecordDataset,
        num_epochs=1)
    # We should shard at the file level, so that all records come from file 0.
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    dataset = dataset.unbatch()
    output = self.getDatasetOutput(dataset)
    files = [elem["file"] for elem in output]
    self.assertEqual(files, [0] * records_per_file)
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for legacy (graph-mode) tf.data iterators.

  Covers one-shot, initializable, reinitializable, string-handle-feedable,
  and remote iterators. Most tests are graph-only because they exercise
  the v1 iterator APIs.

  NOTE(review): this class continues past the end of this chunk; the final
  method below is intentionally left incomplete to splice with the
  following file content.
  """

  @combinations.generate(test_base.graph_only_combinations())
  def testNoGradients(self):
    # Iterator outputs are not differentiable w.r.t. the dataset's sources
    # or captured tensors: all gradients must be None.
    component = constant_op.constant([1.])
    side = constant_op.constant(0.)
    add = lambda x: x + side
    dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
    value = dataset_ops.make_one_shot_iterator(dataset).get_next()
    self.assertIsNone(gradients_impl.gradients(value, component)[0])
    self.assertIsNone(gradients_impl.gradients(value, side)[0])
    self.assertIsNone(
        gradients_impl.gradients(value, [component, side])[0])

  @combinations.generate(test_base.graph_only_combinations())
  def testCapturingStateInOneShotRaisesException(self):
    # One-shot iterators cannot capture stateful objects (here a Variable).
    var = variables.Variable(37.0, name="myvar")
    dataset = (
        dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
        .map(lambda x: x + var))
    with self.assertRaisesRegex(
        ValueError,
        r"`Dataset.make_one_shot_iterator\(\)` does not support "
        "datasets that capture stateful objects.+myvar"):
      dataset_ops.make_one_shot_iterator(dataset)

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIterator(self):
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensor_slices(components).map(
            _map_fn).repeat(14))
    get_next = iterator.get_next()

    # from_tensor_slices drops the leading dimension from each component.
    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIteratorCaptureByValue(self):
    # Same as testOneShotIterator but slicing pre-converted Tensors, so the
    # dataset captures tensor values rather than numpy arrays.
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    tensor_components = tuple(
        [ops.convert_to_tensor(c) for c in components])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensor_slices(tensor_components).map(
            _map_fn).repeat(14))
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testOneShotIteratorInsideContainer(self):
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def within_container():

      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)

      iterator = dataset_ops.make_one_shot_iterator(
          dataset_ops.Dataset.from_tensor_slices(components).map(
              _map_fn).repeat(14))
      return iterator.get_next()

    server = server_lib.Server.create_local_server()

    # Create two iterators within unique containers, and run them to
    # make sure that the resources aren't shared.
    #
    # The test below would fail if cname were the same across both
    # sessions.
    for j in range(2):
      with session.Session(server.target) as sess:
        cname = "iteration%d" % j
        with ops.container(cname):
          get_next = within_container()

        for _ in range(14):
          for i in range(7):
            result = sess.run(get_next)
            for component, result_component in zip(components, result):
              self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIteratorNonBlocking(self):
    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    # Create a session with a single thread to ensure that the
    # one-shot iterator initializer does not deadlock.
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, use_per_session_threads=True)
    with session.Session(config=config) as sess:
      self.assertAllEqual([1, 4, 9], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Test with multiple threads invoking the one-shot iterator concurrently.
    with session.Session(config=config) as sess:
      results = []

      def consumer_thread():
        try:
          results.append(sess.run(next_element))
        except errors.OutOfRangeError:
          results.append(None)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # Exactly one thread wins the single element; the rest see OutOfRange.
      self.assertLen(results, num_threads)
      self.assertLen([None for r in results if r is None], num_threads - 1)
      self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIteratorInitializerFails(self):
    # Define a dataset whose initialization will always fail.
    dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
        sess.run(next_element)

      # Test that subsequent attempts to use the iterator also fail.
      with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
        sess.run(next_element)

    with self.cached_session() as sess:

      def consumer_thread():
        with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
          sess.run(next_element)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

  @combinations.generate(test_base.graph_only_combinations())
  def testSimpleSharedResource(self):
    components = (np.array(1, dtype=np.int64),
                  np.array([1, 2, 3], dtype=np.int64),
                  np.array(37.0, dtype=np.float64))

    server = server_lib.Server.create_local_server()

    # Create two non-overlapping sessions that share the same iterator
    # resource on the same server, and verify that an action of the
    # first session (initializing the iterator) is visible in the
    # second session.
    with ops.Graph().as_default():
      iterator = dataset_ops.make_initializable_iterator(
          dataset_ops.Dataset.from_tensors(components).map(
              lambda x, y, z: (x, y, z)),
          shared_name="shared_iterator")
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Re-initialize the iterator in the first session.
        sess.run(init_op)

    with ops.Graph().as_default():
      # Re-define the iterator manually, without defining any of the
      # functions in this graph, to ensure that we are not
      # accidentally redefining functions with the same names in the
      # new graph.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # Use the iterator without re-initializing in the second session.
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testNotInitializedError(self):
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.from_tensors(components))
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      with self.assertRaisesRegex(errors.FailedPreconditionError,
                                  "iterator has not been initialized"):
        sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testReinitializableIterator(self):
    # One iterator structure re-initialized with two datasets whose element
    # shapes differ only in the (unknown) leading dimension.
    dataset_3 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([1, 2, 3]))
    dataset_4 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([4, 5, 6, 7]))
    iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(dataset_3), [None])
    dataset_3_init_op = iterator.make_initializer(dataset_3)
    dataset_4_init_op = iterator.make_initializer(dataset_4)
    get_next = iterator.get_next()

    self.assertEqual(dataset_ops.get_legacy_output_types(dataset_3),
                     dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(dataset_ops.get_legacy_output_types(dataset_4),
                     dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(
        [None],
        dataset_ops.get_legacy_output_shapes(iterator).as_list())

    with self.cached_session() as sess:
      # The iterator is initially uninitialized.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(get_next)

      # Initialize with one dataset.
      sess.run(dataset_3_init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Initialize with a different dataset.
      sess.run(dataset_4_init_op)
      self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Reinitialize with the first dataset.
      sess.run(dataset_3_init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testReinitializableIteratorWithFunctions(self):

    def g():
      for i in range(10):
        yield i

    iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      dataset_1 = dataset_ops.Dataset.from_generator(
          g, output_types=dtypes.int64)
      sess.run(iterator.make_initializer(dataset_1))
      for expected in range(10):
        self.assertEqual(expected, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

      dataset_2 = dataset_ops.Dataset.from_generator(
          g, output_types=dtypes.int64)
      sess.run(iterator.make_initializer(dataset_2))
      for expected in range(10):
        self.assertEqual(expected, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  @combinations.generate(test_base.default_test_combinations())
  def testReinitializableIteratorStaticErrors(self):
    # Non-matching structure for types and shapes.
    with self.assertRaises(TypeError):
      iterator = iterator_ops.Iterator.from_structure(
          (dtypes.int64, dtypes.float64), [None])

    # Test validation of dataset argument.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64))

    # Incompatible structure.
    with self.assertRaises(ValueError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              ((constant_op.constant([1, 2, 3], dtype=dtypes.int64),),
               (constant_op.constant([4., 5., 6., 7.],
                                     dtype=dtypes.float64),))))

    # Incompatible types.
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32))))

    # Incompatible shapes.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64), ([None], []))
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandle(self):
    # One feedable iterator alternates between two concrete iterators by
    # feeding their string handles.
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset_3),
            dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      # Each underlying iterator keeps its own position across handle swaps.
      self.assertEqual(
          10,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          1,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          20,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          2,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          30,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          3,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          40,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle})

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandleFuture(self):
    # Same scenario as testIteratorStringHandle (duplicated in the original
    # source, presumably for a "future" API mode — confirm against history).
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset_3),
            dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      self.assertEqual(
          10,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          1,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          20,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          2,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          30,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          3,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          40,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle})

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandleReuseTensorObject(self):
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
    initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
    structure_iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(dataset))

    created_ops = len(ops.get_default_graph().get_operations())

    # The default string handle tensor is cached and reused.
    self.assertIs(one_shot_iterator.string_handle(),
                  one_shot_iterator.string_handle())
    self.assertIs(initializable_iterator.string_handle(),
                  initializable_iterator.string_handle())
    self.assertIs(structure_iterator.string_handle(),
                  structure_iterator.string_handle())

    # Assert that getting the (default) string handle creates no ops.
    self.assertEqual(created_ops,
                     len(ops.get_default_graph().get_operations()))

    # Specifying an explicit name will create a new op.
    handle_with_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo", handle_with_name.op.name)
    self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

    handle_with_same_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo_1", handle_with_same_name.op.name)
    self.assertIsNot(handle_with_name, handle_with_same_name)

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandleError(self):
    # Feeding a handle whose element types/shapes don't match the feedable
    # iterator's declared structure fails at run time.
    dataset_int_scalar = (
        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
    dataset_float_vector = (
        dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [])
    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [None])
    feedable_int_any = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32)

    with self.cached_session() as sess:
      handle_int_scalar = sess.run(
          dataset_ops.make_one_shot_iterator(
              dataset_int_scalar).string_handle())
      handle_float_vector = sess.run(
          dataset_ops.make_one_shot_iterator(
              dataset_float_vector).string_handle())

      self.assertEqual(
          1,
          sess.run(feedable_int_scalar.get_next(),
                   feed_dict={handle_placeholder: handle_int_scalar}))

      self.assertEqual(
          2,
          sess.run(feedable_int_any.get_next(),
                   feed_dict={handle_placeholder: handle_int_scalar}))

      with self.assertRaises(errors.InvalidArgumentError):
        print(
            sess.run(feedable_int_vector.get_next(),
                     feed_dict={handle_placeholder: handle_int_scalar}))

      with self.assertRaises(errors.InvalidArgumentError):
        print(
            sess.run(
                feedable_int_vector.get_next(),
                feed_dict={handle_placeholder: handle_float_vector}))

  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    # The iterator lives on cpu:1; remote_call invokes get_next on a target
    # device chosen at run time via a placeholder.
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 3

    with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.session(config=worker_config) as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [1])
      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
            })
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
            })

  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    # NOTE(review): this method is truncated at the chunk boundary; the
    # remainder continues in the following file content.
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()
    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = 
"client" client.tasks[0] = s3.target[len("grpc://"):] config = config_pb2.ConfigProto(cluster_def=cluster_def) worker_devices = [ "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2) ] itr_handles = [] for device in worker_devices: with ops.device(device): src = dataset_ops.Dataset.from_tensor_slices([device]) itr = dataset_ops.make_one_shot_iterator(src) itr_handles.append(itr.string_handle()) targets = dataset_ops.Dataset.from_tensor_slices(worker_devices) handles = dataset_ops.Dataset.from_tensor_slices(itr_handles) @function.Defun(dtypes.string) def loading_func(h): remote_itr = iterator_ops.Iterator.from_string_handle( h, dataset_ops.get_legacy_output_types(itr), dataset_ops.get_legacy_output_shapes(itr)) return remote_itr.get_next() def map_fn(target, handle): return functional_ops.remote_call(args=[handle], Tout=[dtypes.string], f=loading_func, target=target) with ops.device("/job:client"): client_dataset = dataset_ops.Dataset.zip( (targets, handles)).map(map_fn) itr = dataset_ops.make_initializable_iterator(client_dataset) n = itr.get_next() with session.Session(s3.target, config=config) as sess: sess.run(itr.initializer) expected_values = worker_devices for expected in expected_values: self.assertEqual((compat.as_bytes(expected), ), sess.run(n)) with self.assertRaises(errors.OutOfRangeError): sess.run(n) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") with ops.device("/job:localhost/replica:0/task:0/cpu:0"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = iterator_3.string_handle() def _encode_raw(byte_array): return bytes(bytearray(byte_array)) @function.Defun(dtypes.uint8) def _remote_fn(h): handle = script_ops.py_func(_encode_raw, [h], dtypes.string) remote_iterator = iterator_ops.Iterator.from_string_handle( 
handle, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) iterator_3_handle_uint8 = parsing_ops.decode_raw( input_bytes=iterator_3_handle, out_type=dtypes.uint8) remote_op = functional_ops.remote_call( args=[iterator_3_handle_uint8], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.cached_session() as sess: elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [1]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [2]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) @combinations.generate(test_base.graph_only_combinations()) def testRepeatedGetNextWarning(self): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.range(10)) warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: for _ in range(100): iterator.get_next() self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w)) for warning in w: self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( expected_element_structure=tensor_spec.TensorSpec( [], dtypes.float32), expected_output_classes=ops.Tensor, expected_output_types=dtypes.float32, expected_output_shapes=[[]]))) def testTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): tf_value_fn = lambda: 
constant_op.constant(37.0) tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( expected_element_structure=sparse_tensor.SparseTensorSpec( [1], dtypes.int32), expected_output_classes=sparse_tensor.SparseTensor, expected_output_types=dtypes.int32, expected_output_shapes=[[1]]))) def testSparseTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): def tf_value_fn(): return sparse_tensor.SparseTensor(indices=[[0]], values=constant_op.constant( [0], dtype=dtypes.int32), dense_shape=[1]) tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(expected_element_structure={ "a": tensor_spec.TensorSpec([], dtypes.float32), "b": (tensor_spec.TensorSpec([1], dtypes.string), tensor_spec.TensorSpec([], dtypes.string)) }, expected_output_classes={ "a": ops.Tensor, "b": (ops.Tensor, ops.Tensor) }, expected_output_types={ "a": dtypes.float32, 
"b": (dtypes.string, dtypes.string) }, expected_output_shapes={ "a": [], "b": ([1], []) }))) def testNestedTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): def tf_value_fn(): return { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) } tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate(test_base.default_test_combinations()) def testIteratorGetNextName(self): with ops.Graph().as_default(): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(37.0)) next_element = iterator.get_next(name="overridden_name") self.assertEqual("overridden_name", next_element.op.name) @combinations.generate( combinations.combine(tf_api_version=[1, 2], mode="eager", execution_mode=[context.ASYNC, context.SYNC])) def testIteratorEagerIteration(self, execution_mode): with context.eager_mode(), context.execution_mode(execution_mode): val = 0 dataset = dataset_ops.Dataset.range(10) iterator = iter(dataset) for foo in iterator: self.assertEqual(val, foo.numpy()) val += 1 @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunction(self): queue = data_flow_ops.FIFOQueue(10, dtypes.int64) @def_function.function def fn(): dataset = dataset_ops.Dataset.range(10) iterator = iter(dataset) for _ in range(10): queue.enqueue(next(iterator)) fn() for i in range(10): self.assertEqual(queue.dequeue().numpy(), i) 
@combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunctionError(self): # In this test we verify that a function that raises an error ends up # properly deallocating the iterator resource. queue = data_flow_ops.FIFOQueue(10, dtypes.int64) queue.enqueue(0) def init_fn(n): return n def next_fn(_): ds = dataset_ops.Dataset.range(0) return next(iter(ds)) def finalize_fn(n): queue.enqueue(0) return n @def_function.function def fn(): output_signature = tensor_spec.TensorSpec((), dtypes.int64) dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, output_signature) iterator = iter(dataset) next(iterator) with self.assertRaises(errors.OutOfRangeError): fn() self.assertEqual(queue.size().numpy(), 2) @combinations.generate(test_base.eager_only_combinations()) def testLimitedRetracing(self): trace_count = [0] @def_function.function def f(iterator): trace_count[0] += 1 counter = np.int64(0) for elem in iterator: counter += elem return counter dataset = dataset_ops.Dataset.range(5) dataset2 = dataset_ops.Dataset.range(10) for _ in range(10): self.assertEqual(self.evaluate(f(iter(dataset))), 10) self.assertEqual(self.evaluate(f(iter(dataset2))), 45) self.assertEqual(trace_count[0], 1) @combinations.generate(test_base.eager_only_combinations()) def testNestedFunctionsIteratorResource(self): @def_function.function def sum_dataset(ds): it = iter(ds) @def_function.function def next_element(it): return next(it) total = 0 for _ in range(10): total += next_element(it) return total ds = dataset_ops.Dataset.range(10) self.assertEqual(sum_dataset(ds).numpy(), 45) self.assertEqual(sum_dataset(ds).numpy(), 45) @combinations.generate(test_base.default_test_combinations()) def testNestedAutomaticControlDependencies(self): counter_var = variables.Variable(0) def map_fn(x): counter_var.assign_add(1) return x def dataset_fn(): return dataset_ops.Dataset.range(10).map(map_fn) @def_function.function def fn(): it = iter(dataset_fn()) for _ in 
range(10): _ = next(it) return counter_var self.evaluate(counter_var.initializer) self.assertEqual(self.evaluate(fn()), 10)
class AutoShardWithRebatchDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):
  """Tests auto-sharding of pipelines that contain rebatching ops."""

  def _setUpFiles(self, num_files, num_records_per_file):
    # Creates the TFRecord fixture files via the TFRecordDatasetTestBase
    # helper; `self.test_filenames` holds the resulting paths.
    self._num_files = num_files
    self._num_records = num_records_per_file
    self.test_filenames = self._createFiles()

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithLegacyRebatch(self):
    # Tests that RebatchDatasetV1 is a passthrough op.
    self._setUpFiles(num_files=5, num_records_per_file=10)
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    # assert_next verifies the rewritten op order: the Shard is hoisted
    # above FlatMap/Batch/Rebatch (file-level sharding).
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    # Worker 3 of 5 reads only file 3; legacy rebatch splits each batch of 5
    # into 5 batches of 1.
    expected = [[self._record(3, i)] for i in range(10)]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithRebatch(self):
    # Tests that RebatchDatasetV2 is a passthrough op.
    self._setUpFiles(num_files=3, num_records_per_file=5)
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._RebatchDataset(dataset, batch_sizes=[2, 1, 2])
    dataset = distribute._AutoShardDataset(dataset, 3, 1)
    # Worker 1 of 3 reads only file 1; its batch of 5 is re-split into
    # sub-batches of sizes 2, 1, 2.
    expected = [[self._record(1, 0), self._record(1, 1)], [self._record(1, 2)],
                [self._record(1, 3), self._record(1, 4)]]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(sharding_policy=[
                  distribute_options.AutoShardPolicy.DATA,
                  distribute_options.AutoShardPolicy.AUTO
              ]), combinations.combine(with_prefetch=[True, False]))))
  def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                           with_prefetch):
    # This test simulates a distributed environment with 3 workers, each with
    # 1 replica.
    dataset = dataset_ops.Dataset.range(8)
    dataset = dataset.batch(4)
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
    # RebatchDataset(V1) for correctness reasons. This will modify the output
    # of the dataset.
    worker_a_dataset = distribute._RebatchDataset(dataset,
                                                  batch_sizes=[2, 1, 1])
    if with_prefetch:
      worker_a_dataset = worker_a_dataset.prefetch(1)
    worker_a_dataset = distribute._AutoShardDataset(worker_a_dataset,
                                                    3,
                                                    0,
                                                    num_replicas=3)
    expected = [[0, 1], [4, 5]]
    self.assertDatasetProduces(worker_a_dataset, expected)

    worker_b_dataset = distribute._RebatchDataset(dataset,
                                                  batch_sizes=[1, 1, 2])
    if with_prefetch:
      worker_b_dataset = worker_b_dataset.prefetch(1)
    worker_b_dataset = distribute._AutoShardDataset(worker_b_dataset,
                                                    3,
                                                    1,
                                                    num_replicas=3)
    expected = [[2, 3], [6, 7]]
    self.assertDatasetProduces(worker_b_dataset, expected)

    worker_c_dataset = distribute._RebatchDataset(dataset,
                                                  batch_sizes=[1, 2, 1])
    if with_prefetch:
      worker_c_dataset = worker_c_dataset.prefetch(1)
    worker_c_dataset = distribute._AutoShardDataset(worker_c_dataset,
                                                    3,
                                                    2,
                                                    num_replicas=3)
    # Worker 2's rebatched shares are empty after legacy-rebatch rewriting.
    expected = [[], []]
    self.assertDatasetProduces(worker_c_dataset, expected)
class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase):
  """Tests reading from local workers if `target_workers` is `local`."""

  @combinations.generate(test_base.default_test_combinations())
  def testOneLocalWorker(self):
    """A single local worker serves every element exactly once."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=1, num_remote_workers=5)
    element_count = 10
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="local")
    self.assertDatasetProduces(dataset, list(range(element_count)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testLocalWorkers(self, num_local_workers, num_remote_workers):
    """Each local worker contributes one full copy of the range."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 10
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="LOCAL")
    expected = num_local_workers * list(range(element_count))
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testRepeatedDataset(self, num_local_workers, num_remote_workers):
    """`repeat` multiplies the per-local-worker output."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 10
    repetitions = 5
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="LOCAL")
    dataset = dataset.repeat(repetitions)
    expected = num_local_workers * repetitions * list(range(element_count))
    self.assertDatasetProduces(dataset,
                               expected_output=expected,
                               assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testPrefetchingDataset(self, num_local_workers, num_remote_workers):
    """Adding `prefetch` does not change what local reads produce."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 10
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="LOCAL")
    dataset = dataset.prefetch(10)
    expected = num_local_workers * list(range(element_count))
    self.assertDatasetProduces(dataset,
                               expected_output=expected,
                               assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testMultipleEpochs(self, num_local_workers, num_remote_workers):
    """The same distributed dataset can be consumed repeatedly."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 10
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="LOCAL")
    expected = num_local_workers * list(range(element_count))
    for _ in range(10):
      self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testDynamicSharding(self, num_local_workers, num_remote_workers):
    """Dynamic sharding over local workers yields each element once."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 100
    dataset = self.make_distributed_range_dataset(
        element_count,
        test_cluster,
        processing_mode=ShardingPolicy.DYNAMIC,
        target_workers="LOCAL")
    self.assertDatasetProduces(dataset,
                               list(range(element_count)),
                               assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testEmptyDataset(self, num_local_workers, num_remote_workers):
    """A zero-element range produces nothing regardless of worker layout."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 0
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="LOCAL")
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_local_workers=[1, 3],
                               num_remote_workers=[0, 3])))
  def testNonLocalRead(self, num_local_workers, num_remote_workers):
    """This test ensures the remote workers are running and producing data."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    element_count = 10
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="any")
    total_workers = num_local_workers + num_remote_workers
    expected = total_workers * list(range(element_count))
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testNoLocalWorker(self):
    """Local-only reads fail when the cluster has no local workers."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=0, num_remote_workers=3)
    element_count = 10
    dataset = self.make_distributed_range_dataset(element_count,
                                                  test_cluster,
                                                  target_workers="LOCAL")
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Local reads require local tf.data workers, but no local worker is "
        "found."):
      self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testInconsistentTargetWorkers(self):
    """Reusing a named job with a different `target_workers` is rejected."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=3, num_remote_workers=3)
    base_dataset = dataset_ops.Dataset.range(10)
    distributed_datasets = []
    for target_workers in ["AUTO", "ANY", "LOCAL"]:
      distributed_datasets.append(
          self.make_distributed_dataset(base_dataset,
                                        test_cluster,
                                        job_name="test_job",
                                        target_workers=target_workers))
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "but there is already an existing job with that name using "
        "target_workers <AUTO>."):
      for dataset in distributed_datasets:
        self.getDatasetOutput(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testAnonymousJobWithDifferentTargetWorkers(self):
    """Anonymous jobs may each use their own `target_workers` setting."""
    num_local_workers, num_remote_workers = (3, 3)
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers, num_remote_workers)
    element_count = 10
    base_dataset = dataset_ops.Dataset.range(element_count)
    distributed_datasets = {}
    for target_workers in ["AUTO", "ANY", "LOCAL"]:
      distributed_datasets[target_workers] = self.make_distributed_dataset(
          base_dataset, test_cluster, target_workers=target_workers)

    total_workers = num_local_workers + num_remote_workers
    # AUTO and ANY read from every worker; LOCAL reads local workers only.
    self.assertDatasetProduces(distributed_datasets["AUTO"],
                               total_workers * list(range(element_count)),
                               assert_items_equal=True)
    self.assertDatasetProduces(distributed_datasets["ANY"],
                               total_workers * list(range(element_count)),
                               assert_items_equal=True)
    self.assertDatasetProduces(distributed_datasets["LOCAL"],
                               num_local_workers * list(range(element_count)),
                               assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testCoordinatedRead(self):
    """Coordinated reads are incompatible with local-only targeting."""
    test_cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=3, num_remote_workers=3)
    repeated_dataset = dataset_ops.Dataset.range(10).repeat()
    distributed_dataset = self.make_distributed_dataset(repeated_dataset,
                                                        test_cluster,
                                                        job_name="test_job",
                                                        consumer_index=0,
                                                        num_consumers=3,
                                                        target_workers="LOCAL")
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Coordinated reads require non-local workers"):
      self.getDatasetOutput(distributed_dataset)
class ShuffleDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):
  """Serialization (save/restore) tests for `Dataset.shuffle`."""

  def _build_shuffle_dataset(
      self,
      range_limit=10,
      num_repeats=5,
      buffer_size=5,
      seed=None,
      reshuffle_each_iteration=None,
  ):
    # range -> shuffle -> repeat: repeating AFTER shuffle means each pass
    # re-applies (or not) the shuffle per `reshuffle_each_iteration`.
    return dataset_ops.Dataset.range(range_limit).shuffle(
        buffer_size,
        seed=seed,
        reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle_each_iteration=[True, False],
                               buffer_size=[1, 3, 5, 8, 10])))
  def testShuffleCore(self, reshuffle_each_iteration, buffer_size):
    # Fixed seed so save/restore comparisons are deterministic.
    seed = 55
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats
    # pylint: disable=g-long-lambda
    self.run_core_tests(
        lambda: self._build_shuffle_dataset(range_limit=range_limit,
                                            num_repeats=num_repeats,
                                            buffer_size=buffer_size,
                                            seed=seed,
                                            reshuffle_each_iteration=
                                            reshuffle_each_iteration),
        num_outputs)

  @combinations.generate(
      combinations.combine(tf_api_version=1,
                           mode=["graph"],
                           reshuffle_each_iteration=[True, False],
                           buffer_size=[1, 3, 5, 8, 10]))
  def testMultipleIterators(self, reshuffle_each_iteration, buffer_size):
    """Two iterators over one shuffled dataset restore to the same stream."""
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats

    def ds_fn():
      # pylint: disable=cell-var-from-loop
      return self._build_shuffle_dataset(
          range_limit=range_limit,
          num_repeats=num_repeats,
          buffer_size=buffer_size,
          seed=None,  # Iterator seeds are generated non-deterministically.
          reshuffle_each_iteration=reshuffle_each_iteration)
      # pylint: enable=cell-var-from-loop

    with ops.Graph().as_default() as g:
      ds = ds_fn()
      iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
      get_next_ops = [it.get_next() for it in iterators]
      # Register both iterators as saveables so the Saver checkpoints their
      # positions (and shuffle state).
      saveables = [
          contrib_iterator_ops.make_saveable_from_iterator(it)
          for it in iterators
      ]
      for saveable in saveables:
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
      saver = saver_lib.Saver(allow_empty=True)
      with self.session(graph=g) as sess:
        # Save at the start, drain both iterators, then restore and expect
        # the identical sequence again.
        self._save(sess, saver)
        expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
        self._restore(saver, sess)
        actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
        self.match(expected, actual)
class MapAndBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):
  """Checkpointing tests for `tf.data.experimental.map_and_batch`."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(drop_remainder=[True, False])))
  def testNumParallelBatches(self, verify_fn, drop_remainder):
    """Checkpoint/restore with parallelism expressed as parallel batches."""
    range_size = 11
    num_shards = 3
    num_repeats = 2
    batch_size = 5
    num_parallel_batches = 2
    # Shard 0 of 3 over an 11-element range yields floor(11/3) elements,
    # repeated `num_repeats` times.
    total_outputs = (range_size // num_shards) * num_repeats
    num_outputs = (total_outputs // batch_size if drop_remainder else int(
        math.ceil(total_outputs / batch_size)))

    def build_ds(range_start, drop_remainder):

      def _map_fn(x):
        return math_ops.square(x)

      sharded = dataset_ops.Dataset.range(
          range_start, range_start + range_size).shard(num_shards=num_shards,
                                                       index=0)
      return sharded.repeat(num_repeats).apply(
          batching.map_and_batch(map_func=_map_fn,
                                 batch_size=batch_size,
                                 num_parallel_batches=num_parallel_batches,
                                 drop_remainder=drop_remainder))

    verify_fn(self, lambda: build_ds(10, drop_remainder=drop_remainder),
              num_outputs)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(drop_remainder=[True, False])))
  def testNumParallelCalls(self, verify_fn, drop_remainder):
    """Checkpoint/restore with parallelism expressed as parallel calls."""
    range_size = 11
    num_shards = 3
    num_repeats = 2
    batch_size = 5
    num_parallel_calls = 7
    total_outputs = (range_size // num_shards) * num_repeats
    num_outputs = (total_outputs // batch_size if drop_remainder else int(
        math.ceil(total_outputs / batch_size)))

    def build_ds(range_start, drop_remainder=False):

      def _map_fn(x):
        return math_ops.square(x)

      sharded = dataset_ops.Dataset.range(
          range_start, range_start + range_size).shard(num_shards=num_shards,
                                                       index=0)
      return sharded.repeat(num_repeats).apply(
          batching.map_and_batch(map_func=_map_fn,
                                 batch_size=batch_size,
                                 num_parallel_calls=num_parallel_calls,
                                 drop_remainder=drop_remainder))

    verify_fn(self, lambda: build_ds(10, drop_remainder=drop_remainder),
              num_outputs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    """Checkpoint/restore when the mapped elements are sparse tensors."""

    def build_dataset():

      def map_fn(i):
        return sparse_tensor.SparseTensorValue(indices=[[0]],
                                               values=(i * [1]),
                                               dense_shape=[1])

      # 10 sparse elements batched by 5 -> exactly 2 output batches.
      return dataset_ops.Dataset.range(10).apply(
          batching.map_and_batch(map_fn, 5))

    verify_fn(self, build_dataset, num_outputs=2)
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """End-to-end tests for distributing datasets through the tf.data service.

  Tests spin up an in-process `TestCluster` (dispatcher plus workers) and
  distribute datasets to it via the helpers on `data_service_test_base.TestBase`
  (`make_distributed_dataset` / `make_distributed_range_dataset`).
  """

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    # A single worker should produce the range exactly once, in order.
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(compression=[None, "AUTO"])))
  def testDistributeCompression(self, compression):
    # Both supported compression settings must round-trip the data.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, compression=compression)
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidCompression(self):
    # An unrecognized compression string is rejected at dataset-build time.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with self.assertRaisesRegex(ValueError, "Invalid compression argument"):
      self.make_distributed_range_dataset(10, cluster, compression="foo")

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    # Sparse tensors survive serialization through the service.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    # Ragged tensors survive serialization through the service.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              init_source=["textfile", "keyvaluetensor", "dataset"])))
  def testDistributeLookupTable(self, init_source):
    # Static hash-table lookups work inside a distributed dataset,
    # regardless of how the table was initialized.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    initializer = self.lookupTableInitializer(init_source, [10, 11])
    table = lookup_ops.StaticHashTable(initializer, -1)
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    ds = self.make_distributed_dataset(ds, cluster)
    self.evaluate(lookup_ops.tables_initializer())
    # Keys 0 and 1 map to 10 and 11; key 2 falls back to the default (-1).
    self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(value_rank=[0, 1])))
  def testDistributeMutableHashTable(self, value_rank):
    # Mutable hash-table lookups work inside a distributed dataset, for
    # scalar and nested (rank-1) values.

    def value(v):
      # Nest v into [v, v] `value_rank` times to build a value of that rank.
      for _ in range(value_rank):
        v = [v, v]
      return v

    v1 = value(10)
    v2 = value(11)
    default_value = value(-1)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                        default_value)
    self.evaluate(table.insert([0, 1], [v1, v2]))
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(
        ds, [v1, v2, default_value], requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testDifferentShuffleOrders(self):
    # With an unseeded shuffle, the two workers should produce the range in
    # different orders.
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = data_service_test_base.TestCluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    output = self.getDatasetOutput(ds)

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.default_test_combinations())
  def testMultipleEpochs(self):
    # Re-iterating the same distributed dataset restarts it from scratch.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    for _ in range(10):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatedDataset(self):
    # `repeat` applied after distribution multiplies the output.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentEpoch(self):
    # Several independent distributed datasets can be consumed in lockstep;
    # each sees its own complete copy of the data.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    get_nexts = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      get_nexts.append(self.getNext(ds))
      results.append([])
    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = self.evaluate(get_nexts[dataset_ind]())
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
  def testMultiWorker(self):
    # In parallel_epochs mode every worker produces the full range, so the
    # combined (unordered) output contains each element `num_workers` times.
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testMaxOutstandingRequests(self):
    # Limiting the client to one outstanding request must not change the
    # set of elements produced.
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    # A distributed dataset can be created and consumed from inside a
    # `tf.function`.
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobName(self):
    # Two datasets registered under the same job name share one job: between
    # them they consume each element exactly once.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 1000

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)
    results = []
    # Interleave reads from both consumers, then drain whatever remains.
    for _ in range(num_elements // 5):
      results.append(self.evaluate(get_next_1()))
      results.append(self.evaluate(get_next_2()))
    results += self.getIteratorOutput(get_next_1)
    results += self.getIteratorOutput(get_next_2)
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.default_test_combinations())
  def testDifferentJobNames(self):
    # Distinct job names create independent jobs with full output each.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    # With a shared job, whichever dataset is iterated first drains the job;
    # the other sees nothing until a new iteration starts the job over.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobNameRepeat(self):
    # Repeating both shared-job datasets still yields each element exactly
    # `num_repetitions` times across the two consumers combined.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(self.evaluate(get_next_1()))
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(self.evaluate(get_next_2()))
    results += self.getIteratorOutput(get_next_1)
    results += self.getIteratorOutput(get_next_2)
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    # Once the only iterator is deleted, the job's worker task should be
    # garbage-collected (GC interval/timeout are set very low to keep the
    # polling loop short).
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)
    del it
    # Poll until the task disappears; hangs (and times out) on regression.
    while cluster.workers[0].num_tasks() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    # A job must not be garbage-collected while an iterator still holds it.
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(cluster.workers[0].num_tasks(), 2)
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.workers[0].num_tasks() > 1:
      time.sleep(0.1)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)

  @combinations.generate(test_base.default_test_combinations())
  def testApplyDeterminismOption(self):
    # Disabling determinism via options must propagate through distribution:
    # the slow first element may be reordered after the fast ones.
    elements = list(range(10))
    cluster = data_service_test_base.TestCluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        # Only element 0 sleeps, so a non-deterministic interleave can
        # overtake it.
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a dataset with stateful ops under the given policy.

    Raises whatever error the policy dictates (e.g. FailedPreconditionError
    for FAIL); succeeds silently for IGNORE/WARN.
    """
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))
    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)
    cluster = data_service_test_base.TestCluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    self.getDatasetOutput(ds)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    # IGNORE and WARN allow stateful ops to be distributed.
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.default_test_combinations())
  def testStatefulError(self):
    # FAIL rejects stateful ops at distribution time.
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeFromInterleave(self):
    # Creating a distributed dataset inside an interleave function must not
    # break the outer pipeline.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      # NOTE(review): the distributed dataset is built but discarded and the
      # local `dataset` is returned — presumably this only exercises creating
      # it inside interleave; confirm intent.
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeNonStringAddresses(self):
    # The service address must be a string.
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeEmptyAddress(self):
    # The service address must be non-empty.
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeExplicitProtocol(self):
    # A "grpc://" prefix on the service address is accepted.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(10)
    ds = ds.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service="grpc://" + cluster.dispatcher_address()))
    self.assertDatasetProduces(ds, list(range(10)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidProtocol(self):
    # An unknown protocol prefix fails when the dataset is consumed.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(
        errors.NotFoundError,
        "No credentials factory has been registered for protocol grp"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs",
              service="grp://" + cluster.dispatcher_address()))
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    # Unknown processing modes are rejected eagerly.
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    # Zipping a distributed_epoch dataset with a parallel_epochs one works;
    # element pairing is nondeterministic, so only the multiset is checked.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertDatasetProduces(
        ds,
        list(zip(range(num_elements), range(num_elements))),
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    # The same job name cannot be reused with a different processing mode.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetId(self):
    # register_dataset + from_dataset_id round-trips a dataset.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdMultipleComponents(self):
    # Structured (nested dict/tuple) elements round-trip through dataset ids.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    # A mismatched element_spec surfaces as a FailedPreconditionError when
    # the first element is fetched.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdNotRegistered(self):
    # Reading from an unregistered dataset id raises NotFoundError.
    cluster = data_service_test_base.TestCluster(num_workers=1)
    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000
    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the second
    # element slowly. Fetching the first element triggers prefetching of the
    # second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while trying
    # to garbage collect the dataset iterator.

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterEquivalentDatasets(self):
    # Structurally identical datasets are deduplicated to one id.
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                             ds_1)
    id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                             ds_2)
    self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterDifferentDatasets(self):
    # Different datasets get distinct ids.
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                             ds_1)
    id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                             ds_2)
    self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

  @combinations.generate(test_base.default_test_combinations())
  def testTwoLevelDistribute(self):
    # Distribute, group-by-window, then distribute again: each window of
    # same-length strings must stay intact through the second distribution.
    cluster_1_size = 3
    cluster_1 = data_service_test_base.TestCluster(num_workers=cluster_1_size)
    cluster_2 = data_service_test_base.TestCluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)
    get_next = self.getNext(ds)
    for _ in range(num_sizes):
      element = self.evaluate(get_next())
      # Each first-level worker contributes one identical batch per size.
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(self.evaluate(get_next()), element)
    self.assertEmpty(self.getIteratorOutput(get_next))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testDistributeLargeGraph(self):
    # A dataset graph bigger than the default 4MB gRPC message limit must
    # still be registerable and usable.
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testBasic(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testCanHandleUnknownRank(self): dataset = dataset_ops.Dataset.from_tensors("xxx") # decode_image results in a tensor of completely unknown shape (i.e. unknown # rank) dataset = dataset.map(image_ops.decode_image) self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(dataset)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) # Note that we are just testing the dataset shapes, not the actual output. self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(rebatched_dataset)) @combinations.generate(test_base.default_test_combinations()) def testCanHandleUnknownDims(self): dataset = dataset_ops.Dataset.range(1000) dataset = dataset.batch(10, drop_remainder=False) dataset = dataset.batch(10, drop_remainder=False) self.assertEqual([[None, None]], [ts.as_list() for ts in _flat_shapes(dataset)]) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) # Note that we are just testing the dataset shapes, not the actual output. 
self.assertEqual([[None, None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) @combinations.generate(test_base.default_test_combinations()) def testScalarInputError(self): dataset = dataset_ops.Dataset.range(1024) distribute._RebatchDataset(dataset.batch(4), num_replicas=4) with self.assertRaisesRegexp(ValueError, "at least one dimension"): distribute._RebatchDataset(dataset, num_replicas=4) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testBatchNotDivisibleByNumReplicas(self, drop_remainder): dataset = dataset_ops.Dataset.range(1024).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [] i = 0 for _ in range(32): # number of steps # first four minibatches have seven elements for _ in range(4): expected_output.append([k for k in range(i, i + 7)]) i += 7 # last minibatch has four elements expected_output.append([k for k in range(i, i + 4)]) i += 4 self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testBatchSizeNotDivisibleByNumReplicas2(self): dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5) # This will rebatch into sub-batches of size 4, since # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get # data. 
expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)] expected_output.extend([[]]) # Last replica gets an empty batch expected_output.extend( [[k for k in range(i, i + 4)] for i in range(16, 32, 4)]) expected_output.extend([[]]) # Last replica gets an empty batch self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testTupleOutput(self): dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) expected_output = [([k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testNestedDictionaryOutput(self): dataset = dataset_ops.Dataset.range(1024).map( lambda x: {"a": x, "b": {"c": x}}).batch(32) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) expected_output = [{"a": [k for k in range(i, i + 8)], # pylint: disable=g-complex-comprehension "b": {"c": [k for k in range(i, i + 8)]}} for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testFinalPartialBatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(1032).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # if drop_remainder, the final partial batch is dropped, even though it # makes up a complete minibatch. 
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension if not drop_remainder: # The last partial batch of size 8 is split over 4 replicas expected_output.extend( [[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) def testFinalPartialBatchAfterRebatch(self, drop_remainder): dataset = dataset_ops.Dataset.range(34).batch( 32, drop_remainder=drop_remainder) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[8] if drop_remainder else [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] # pylint: disable=g-complex-comprehension if not drop_remainder: # The last partial batch of size 2 is split over 4 replicas expected_output += [[32], [33], [], []] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testMultipleBatches(self): dataset = dataset_ops.Dataset.range(128).batch(4).batch(8) self.assertEqual([[None, None]], [ts.as_list() for ts in _flat_shapes(dataset)]) # Each element is a list of 8 elements where each element is a list of 4. expected_output = [[[j, j + 1, j + 2, j + 3] # pylint: disable=g-complex-comprehension for j in range(i, i + 32, 4)] # generates 8 elements for i in range(0, 128, 32)] self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, 4) self.assertEqual([[None, None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Each element is a list of 2 elements where each element is a list of 4. 
expected_output = [[[j, j + 1, j + 2, j + 3] # pylint: disable=g-complex-comprehension for j in range(i, i + 8, 4)] # generates 2 elements for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatch(self): dataset = dataset_ops.Dataset.range(1024).apply( batching.map_and_batch(math_ops.square, 32)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k**2 for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension for i in range(0, 1024, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testMapAndBatchWithCapturedInput(self): captured_t = variables.Variable(42) dataset = dataset_ops.Dataset.range(1024).apply( batching.map_and_batch(lambda x: captured_t, 32)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[42 for _ in range(i, i + 8)] # pylint: disable=g-complex-comprehension for i in range(0, 1024, 8)] self.evaluate(variables.global_variables_initializer()) self.assertDatasetProduces( rebatched_dataset, expected_output, requires_initialization=True) @combinations.generate(test_base.default_test_combinations()) def testPaddedBatch(self): dataset = dataset_ops.Dataset.range(128).batch( 4, drop_remainder=True).padded_batch( 8, padded_shapes=[5]) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) # Each element is a list of 8 elements in which each element is a list of 5 # elements, first four are numbers and the last one is a padded zero. 
expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension for j in range(i, i + 32, 4)] # generates 8 elements for i in range(0, 128, 32)] self.assertDatasetProduces(dataset, expected_output) self.assertEqual([[None, 5]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Each element is a list of 2 elements in which each element is a list of 5 # elements, first four are numbers and the last one is a padded zero. expected_output = [[[j, j + 1, j + 2, j + 3, 0] # pylint: disable=g-complex-comprehension for j in range(i, i + 8, 4)] # generates 2 elements for i in range(0, 128, 8)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testConcatenate(self): dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset1.concatenate(dataset2) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = ([[i, i + 1] for i in range(0, 64, 2)] + [[i, i + 1] for i in range(0, 32, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testConcatenateDifferentShapes(self): dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset1.concatenate(dataset2) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] + [[i, i + 1] for i in range(0, 32, 2)]) self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testZip(self): dataset1 = dataset_ops.Dataset.range(64).batch(8) dataset2 = dataset_ops.Dataset.range(32).batch(8) 
dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None], [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testZipDifferentShapes(self): dataset1 = dataset_ops.Dataset.range(64).batch(16) dataset2 = dataset_ops.Dataset.range(32).batch(8) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None], [None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1]) for i in range(0, 32, 2)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testFlatMapBatching(self): dataset = dataset_ops.Dataset.range(2).flat_map( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda 32)) # Two elements where each element is range(32) expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # Two elements where each element is a list of 4 elements where each element # is a list of 8. 
expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension for _ in range(2) for i in range(0, 32, 8)] # generates 4 elements self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testInterleaveBatching(self): dataset = dataset_ops.Dataset.range(2).interleave( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda 32), cycle_length=2) # Two elements where each element is range(32) expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] expected_output += expected_output self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testParallelInterleaveBatching(self): dataset = dataset_ops.Dataset.range(2).interleave( lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda 32), cycle_length=2, num_parallel_calls=2) # Two elements where each element is range(32) expected_output = [[k for k in range(32)] for _ in range(2)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] expected_output += expected_output self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowStaticBatch(self): dataset = dataset_ops.Dataset.from_tensor_slices( [[array_ops.constant(i, 
dtype=dtypes.int64)] * 3 for i in range(40)]) reduce_fn = lambda bucket_id, ds: ds.batch( # pylint: disable=g-long-lambda batch_size=10) dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10)) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None, 3]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) # pylint: disable=g-complex-comprehension expected_output = [[[j + i * 4 + k * 20] * 3 for i in range(5)] for j in range(4) for k in range(2)] self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowDynamicBatch(self): # {0, 1, 0, 1, ...} dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) def reduce_fn(key, ds): # key == 0 -> .batch(5) # key == 1 -> .batch(10) return ds.batch(batch_size=(key + 1) * 5) dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=10)) dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and # the batches of 10 (value == 1) split into minibatches of (5, 5) # [(batch_size, value), ...] 
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)] pairs = pairs * 2 expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowDynamicBatchWithPartialBatch(self): # {0, 1, 0, 1, ...} dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) def reduce_fn(key, ds): # key == 0 -> .batch(5) # key == 1 -> .batch(10) return ds.batch(batch_size=(key + 1) * 5) dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1), (1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)] expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self): # This test exercises nested batch functionality, dynamic batch size # and drop_remainder=True together. dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2) def reduce_fn(key, ds): # key == 0 -> .batch(5) # key == 1 -> .batch(10) return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True) dataset = dataset.apply( grouping.group_by_window( key_func=lambda x: x, reduce_func=reduce_fn, window_size=11)) dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and # the batches of 10 (value == 1) split into minibatches of (5, 5) # [(batch_size, value), ...] 
pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)] expected_output = [[value] * batch_size for batch_size, value in pairs] self.assertDatasetProduces(dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testScanAfterBatch(self): dataset = dataset_ops.Dataset.range(40).batch(10).apply( scan_ops.scan(np.int64(2), lambda state, value: (state, value * state))) dataset = distribute._RebatchDataset(dataset, num_replicas=2) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)]) expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testMakeBatchedFeaturesDataset(self): # Set up fn = os.path.join(self.get_temp_dir(), "tf_record.txt") writer = python_io.TFRecordWriter(fn) for i in range(1024): writer.write( example_pb2.Example( features=feature_pb2.Features( feature={ "value": feature_pb2.Feature( int64_list=feature_pb2.Int64List(value=[i])) })).SerializeToString()) writer.close() dataset = readers.make_batched_features_dataset( file_pattern=fn, batch_size=32, features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)}, shuffle=False, num_epochs=1, drop_final_batch=False) rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4) self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]) expected_output = [{ "value": [k for k in range(i, i + 8)] } for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension self.assertDatasetProduces(rebatched_dataset, expected_output) @combinations.generate(test_base.default_test_combinations()) def testRaggedTensorDataset(self): # Set up a dataset that produces ragged tensors with a static batch size. 
row_lengths = np.random.randint(8, size=128) values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32) dataset = dataset_ops.Dataset.from_tensor_slices( ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths)) dataset = dataset.batch(32, drop_remainder=True) # The map changes the internal representation of the ragged tensor. # This test will fail if we don't normalize the tensor representation. dataset = dataset.map(lambda x: x) dataset = distribute._RebatchDataset(dataset, num_replicas=8) # After rebatching, batch size is now 4. expected_output = [] value_index = 0 for batch_row_lengths in row_lengths.reshape((-1, 4)): num_values = np.sum(batch_row_lengths) expected_output.append( ragged_tensor.RaggedTensor.from_row_lengths( values[value_index:(value_index + num_values)], batch_row_lengths)) value_index += num_values self.assertDatasetProduces(dataset, expected_output)
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests `distribute._RebatchDataset`, which splits each batch of an input
  dataset into `num_replicas` smaller "minibatches".

  Each test builds a batched dataset, rebatches it, and checks both the
  inferred output shapes (via `_flat_shapes`) and the produced elements.
  """

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testBasic(self, drop_remainder):
    # 1024 elements in batches of 32, rebatched across 4 replicas ->
    # minibatches of 8.
    dataset = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    # With drop_remainder the batch dimension is statically known (8);
    # otherwise it stays unknown (None).
    self.assertEqual(
        [[8] if drop_remainder else [None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[k for k in range(i, i + 8)] for i in range((0), 1024, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCanHandleUnknownRank(self):
    # Rebatching must not crash when a component's rank is unknown.
    dataset = dataset_ops.Dataset.from_tensors("xxx")
    # decode_image results in a tensor of completely unknown shape (i.e. unknown
    # rank)
    dataset = dataset.map(image_ops.decode_image)
    self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(dataset))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    # Note that we are just testing the dataset shapes, not the actual output.
    self.assertEqual([tensor_shape.TensorShape(None)],
                     _flat_shapes(rebatched_dataset))

  @combinations.generate(test_base.default_test_combinations())
  def testCanHandleUnknownDims(self):
    # Nested batching without drop_remainder leaves both dims unknown.
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.batch(10, drop_remainder=False)
    dataset = dataset.batch(10, drop_remainder=False)
    self.assertEqual([[None, None]],
                     [ts.as_list() for ts in _flat_shapes(dataset)])
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    # Note that we are just testing the dataset shapes, not the actual output.
    self.assertEqual(
        [[None, None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

  @combinations.generate(test_base.default_test_combinations())
  def testScalarInputError(self):
    dataset = dataset_ops.Dataset.range(1024)
    # Sanity check: a batched dataset is accepted (result intentionally
    # discarded — only construction success matters here).
    distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
    # An unbatched (scalar-element) dataset must be rejected with a hint to
    # add a `batch` transformation.
    with self.assertRaisesRegex(ValueError, ("You can fix the issue "
                                             "by adding the `batch`")):
      distribute._RebatchDataset(dataset, num_replicas=4)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
    # Batch size 32 over 5 replicas: minibatch size is ceil(32/5) = 7 for the
    # first four replicas; the fifth gets the remaining 4 elements.
    dataset = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
    self.assertEqual(
        [[None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = []
    i = 0
    for _ in range(32):  # number of steps
      # first four minibatches have seven elements
      for _ in range(4):
        expected_output.append([k for k in range(i, i + 7)])
        i += 7
      # last minibatch has four elements
      expected_output.append([k for k in range(i, i + 4)])
      i += 4
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testBatchSizeNotDivisibleByNumReplicas2(self):
    dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
    # This will rebatch into sub-batches of size 4, since
    # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
    # data.
    expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)]
    expected_output.extend([[]])  # Last replica gets an empty batch
    expected_output.extend([[k for k in range(i, i + 4)]
                            for i in range(16, 32, 4)])
    expected_output.extend([[]])  # Last replica gets an empty batch
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testTupleOutput(self):
    # Rebatching applies to every component of a tuple element.
    dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    expected_output = [
        ([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
         [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)
    ]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedDictionaryOutput(self):
    # Rebatching applies to every leaf of a nested dict structure.
    dataset = dataset_ops.Dataset.range(1024).map(lambda x: {
        "a": x,
        "b": {
            "c": x
        }
    }).batch(32)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    expected_output = [
        {
            "a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
            "b": {
                "c": [k for k in range(i, i + 8)]
            }
        } for i in range(0, 1024, 8)
    ]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testFinalPartialBatch(self, drop_remainder):
    # 1032 = 32 * 32 + 8, so the source dataset ends with a partial batch of 8.
    dataset = dataset_ops.Dataset.range(1032).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual(
        [[8] if drop_remainder else [None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # if drop_remainder, the final partial batch is dropped, even though it
    # makes up a complete minibatch.
    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    if not drop_remainder:
      # The last partial batch of size 8 is split over 4 replicas
      expected_output.extend([[k for k in range(i, i + 2)]
                              for i in range(1024, 1032, 2)])
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testFinalPartialBatchAfterRebatch(self, drop_remainder):
    # 34 = 32 + 2: the final batch of 2 is smaller than one minibatch.
    dataset = dataset_ops.Dataset.range(34).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual(
        [[8] if drop_remainder else [None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
    if not drop_remainder:
      # The last partial batch of size 2 is split over 4 replicas
      expected_output += [[32], [33], [], []]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMultipleBatches(self):
    # Only the outermost batch dimension is rebatched; the inner batch of 4
    # is untouched.
    dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
    self.assertEqual([[None, None]],
                     [ts.as_list() for ts in _flat_shapes(dataset)])

    # Each element is a list of 8 elements where each element is a list of 4.
    expected_output = [
        [
            [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
            for j in range(i, i + 32, 4)
        ]  # generates 8 elements
        for i in range(0, 128, 32)
    ]
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, 4)
    self.assertEqual(
        [[None, None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Each element is a list of 2 elements where each element is a list of 4.
    expected_output = [
        [
            [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
            for j in range(i, i + 8, 4)
        ]  # generates 2 elements
        for i in range(0, 128, 8)
    ]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testRaggedTensorDataset(self):
    # Set up a dataset that produces ragged tensors with a static batch size.
    row_lengths = np.random.randint(8, size=128)
    values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices(
        ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths))
    dataset = dataset.batch(32, drop_remainder=True)

    # The map changes the internal representation of the ragged tensor.
    # This test will fail if we don't normalize the tensor representation.
    dataset = dataset.map(lambda x: x)

    dataset = distribute._RebatchDataset(dataset, num_replicas=8)
    # After rebatching, batch size is now 4.
    expected_output = []
    value_index = 0
    for batch_row_lengths in row_lengths.reshape((-1, 4)):
      num_values = np.sum(batch_row_lengths)
      expected_output.append(
          ragged_tensor.RaggedTensor.from_row_lengths(
              values[value_index:(value_index + num_values)],
              batch_row_lengths))
      value_index += num_values
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNoOutputShapes(self):
    # Some datasets, e.g. datasets with None tensors, have components without
    # output shapes. Test that this doesn't break rebatching shape inference
    # logic.
    dataset = dataset_ops.Dataset.range(4)
    dataset = dataset.map(lambda x: (x, None))
    dataset = dataset.batch(4, drop_remainder=True)
    _ = distribute._RebatchDataset(dataset, num_replicas=2)
class CacheCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):
  """Checkpoint/restore tests for `Dataset.cache`, in both memory-cache
  (empty filename) and file-cache modes.

  The dataset under test is `range(10).cache(...).repeat(3)`, so a full run
  yields [0..9] three times (30 elements).
  """

  def setUp(self):
    super(CacheCheckpointTest, self).setUp()
    self.range_size = 10
    self.num_repeats = 3
    self.num_outputs = self.range_size * self.num_repeats
    self.cache_file_prefix = "test"

  def make_dataset_fn(self, is_memory):
    """Returns a zero-arg factory for the cached-and-repeated range dataset.

    An empty filename selects the in-memory cache implementation.
    """
    filename = "" if is_memory else os.path.join(self.get_temp_dir(),
                                                 self.cache_file_prefix)

    def ds_fn():
      return dataset_ops.Dataset.range(
          self.range_size).cache(filename).repeat(self.num_repeats)

    return ds_fn

  def expected_outputs(self):
    """All elements a full, uninterrupted run should produce."""
    return list(range(self.range_size)) * self.num_repeats

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Consume 5 elements, checkpointing at the end.
    produced = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(produced, range(5))

    # Restoring mid-epoch must resume exactly where the checkpoint was taken.
    produced.extend(
        self.gen_outputs(ds_fn, [],
                         self.num_outputs - 5,
                         ckpt_saved=True,
                         verify_exhausted=False))
    self.assertSequenceEqual(produced, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Take the checkpoint after 5 elements but keep reading up to 8, so the
    # iterator is ahead of the saved state when we restore.
    produced = self.gen_outputs(ds_fn, [5],
                                8,
                                verify_exhausted=False,
                                save_checkpoint_at_end=False)
    self.assertSequenceEqual(produced, range(8))

    # Elements 5..7 were produced after the checkpoint; restoring rewinds
    # to element 5, so drop them before continuing.
    produced = produced[:5]
    produced.extend(
        self.gen_outputs(ds_fn, [],
                         self.num_outputs - 5,
                         ckpt_saved=True,
                         verify_exhausted=False))
    self.assertSequenceEqual(produced, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointAfterOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # 15 elements covers one full epoch plus half of the next, so the cache
    # is completely written before the checkpoint.
    produced = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(produced, list(range(10)) + list(range(5)))

    # Restore and drain the remainder.
    produced.extend(
        self.gen_outputs(ds_fn, [],
                         self.num_outputs - 15,
                         ckpt_saved=True,
                         verify_exhausted=False))
    self.assertSequenceEqual(produced, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Checkpoint at 15 elements but read on to 18 before stopping.
    produced = self.gen_outputs(ds_fn, [15],
                                18,
                                verify_exhausted=False,
                                save_checkpoint_at_end=False)
    self.assertSequenceEqual(produced, list(range(10)) + list(range(8)))

    # Restoring rewinds to element 15 (one epoch + 5); the run after restore
    # must supply everything past that point.
    produced = list(range(10)) + list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 15,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(produced, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Checkpoint after 5 (mid cache-write) but keep going past a full epoch.
    produced = self.gen_outputs(ds_fn, [5],
                                13,
                                verify_exhausted=False,
                                save_checkpoint_at_end=False)
    self.assertSequenceEqual(produced, list(range(10)) + list(range(3)))

    # Since we ran for more than one epoch, the cache was completely written.
    # The ckpt was saved when the iterator was in cache-write mode. Test that
    # the iterator falls back to read mode after restoring if the cache has
    # been completely written.
    produced = list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 5,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(produced, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointUnusedWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Checkpoint before get_next is called even once.
    produced = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
    self.assertSequenceEqual(produced, [])

    # A restore from the untouched state must still yield the full sequence.
    produced = self.gen_outputs(ds_fn, [],
                                self.num_outputs,
                                ckpt_saved=True,
                                verify_exhausted=False)
    self.assertSequenceEqual(produced, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and checkpoint.
    produced = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(produced, range(5))

    # Restore, immediately checkpoint again without consuming anything.
    produced.extend(
        self.gen_outputs(ds_fn, [], 0, ckpt_saved=True,
                         verify_exhausted=False))
    self.assertSequenceEqual(produced, range(5))

    # Restore once more and drain the rest of the elements.
    produced.extend(
        self.gen_outputs(ds_fn, [],
                         self.num_outputs - 5,
                         ckpt_saved=True,
                         verify_exhausted=False))
    self.assertSequenceEqual(produced, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testUnusedCheckpointError(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and save a checkpoint that we then never restore.
    produced = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(produced, range(5))

    if is_memory:
      # The in-memory cache simply starts over for a fresh iterator.
      produced = self.gen_outputs(ds_fn, [],
                                  self.num_outputs,
                                  verify_exhausted=False)
      self.assertSequenceEqual(produced, self.expected_outputs())
    else:
      # Since the complete cache has not been written, a new iterator which does
      # not restore the checkpoint will throw an error since there is a partial
      # cache shard.
      with self.assertRaises(errors.AlreadyExistsError):
        produced = self.gen_outputs(ds_fn, [],
                                    self.num_outputs,
                                    verify_exhausted=False)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testIgnoreCheckpointIfCacheWritten(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 15 elements and save ckpt. This will write the complete cache.
    produced = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(produced, list(range(10)) + list(range(5)))

    # Build the iterator again but do not restore from ckpt. Since the cache
    # has already been written we should be able to use it.
    produced = self.gen_outputs(ds_fn, [],
                                self.num_outputs,
                                verify_exhausted=False)
    self.assertSequenceEqual(produced, list(range(10)) * 3)
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, parameterized.TestCase): def setUp(self): super(SnapshotDatasetTest, self).setUp() self.removeTFRecords() def removeTFRecords(self): for filename in self.test_filenames: os.remove(filename) self.test_filenames = [] def setUpTFRecord(self, num_files=10, num_records=10): self._num_files = num_files self._num_records = num_records self.test_filenames = self._createFiles() def makeSnapshotDirectory(self): tmpdir = self.get_temp_dir() tmpdir = os.path.join(tmpdir, "snapshot") os.mkdir(tmpdir) return tmpdir def assertSnapshotDirectoryContains(self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files): dirlist = os.listdir(directory) self.assertLen(dirlist, num_fingerprints) for i in range(num_fingerprints): fingerprint_dir = os.path.join(directory, dirlist[i]) fingerprint_dir_list = sorted(os.listdir(fingerprint_dir)) self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1) self.assertEqual(fingerprint_dir_list[num_runs_per_fp], "snapshot.metadata") for j in range(num_runs_per_fp): run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j]) run_dirlist = sorted(os.listdir(run_dir)) self.assertLen(run_dirlist, num_snapshot_files) file_counter = 0 for filename in run_dirlist: self.assertEqual(filename, "%08d.snapshot" % file_counter) file_counter += 1 @combinations.generate(test_base.default_test_combinations()) def testWriteDifferentPipelinesInOneDirectory(self): tmpdir = self.makeSnapshotDirectory() dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(1000))) dataset = dataset_ops.Dataset.range(1001) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(1001))) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotMultipleSimultaneous(self): tmpdir = 
self.makeSnapshotDirectory() dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(tmpdir)) next1 = self.getNext(dataset1) dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply(snapshot.snapshot(tmpdir)) next2 = self.getNext(dataset2) for _ in range(1000): self.evaluate(next1()) self.evaluate(next2()) # we check that only one copy of the metadata has been written, and the # one that lost the race would be in passthrough mode. self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testGetNextCreatesDir(self): tmpdir = self.makeSnapshotDirectory() # We create two iterators but call getNext on only one. dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(tmpdir)) next1 = self.getNext(dataset1) dataset2 = dataset_ops.Dataset.range(1001) dataset2 = dataset2.apply(snapshot.snapshot(tmpdir)) _ = self.getNext(dataset2) for _ in range(1000): self.evaluate(next1()) # We check that only one directory is created. 
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
          ])))
  def testWriteSnapshotSimpleSuccessful(self, compression):
    """Writes a snapshot and checks the on-disk layout for each codec."""
    tmpdir = self.makeSnapshotDirectory()

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(
        snapshot.snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset, list(range(1000)))

    # Expect exactly one fingerprint dir, one run, one shard.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testWriteSnapshotRepeatAfterwards(self):
    """`repeat` applied after `snapshot` replays the one written snapshot."""
    tmpdir = self.makeSnapshotDirectory()

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    # Repeating downstream must not create extra snapshot runs or shards.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression=[
              snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
          ])))
  def testReadSnapshotBackAfterWrite(self, compression):
    """Reads elements back from the snapshot after the sources are deleted."""
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset2, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testReadShuffledSnapshotAfterWrite(self):
    """`shuffle_on_read` must change element order but preserve the set."""
    self.setUpTFRecord(num_files=10, num_records=50)
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 50)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    # Tiny shard size forces many shard files, giving shuffle something to
    # reorder.
    dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, shuffle_on_read=True))
    next2 = self.getNext(dataset2)

    res1 = self.evaluate(next2())
    res2 = self.evaluate(next2())
    res3 = self.evaluate(next2())
    res4 = self.evaluate(next2())
    res5 = self.evaluate(next2())

    # make sure that we don't read the file back in the same order.
    self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])

    # make sure all the elements are still there
    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(
        snapshot.snapshot(tmpdir, shuffle_on_read=True))
    self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testReadSnapshotParallelAfterWrite(self):
    """Round-trips a snapshot using multiple reader threads."""
    self.setUpTFRecord(10, 4000)
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 4000)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(
            tmpdir,
            shard_size_bytes=1024 * 1024,
            num_reader_threads=2,
            reader_buffer_size=10))
    # Parallel reads do not guarantee ordering.
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

    # remove the original files and try to read the data back only from
    # snapshot.
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(
            tmpdir,
            shard_size_bytes=1024 * 1024,
            num_reader_threads=2,
            reader_buffer_size=10))
    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(compression=[
                  snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
              ]),
              combinations.combine(threads=2, size=[1, 2]) +
              combinations.combine(threads=8, size=[1, 4, 8]))))
  def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
                                                  size):
    """Round-trips a snapshot written with multiple writer threads."""
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(
            tmpdir,
            compression=compression,
            num_writer_threads=threads,
            writer_buffer_size=size))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from
    # snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, compression=compression))
    # Multi-threaded writes may interleave records, so compare as a multiset.
    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSameFingerprintWithDifferentInitializationOrder(self):
    """Equivalent pipelines built in a different order share one snapshot."""
    tmpdir = self.makeSnapshotDirectory()

    dataset1 = dataset_ops.Dataset.range(0, 100)
    dataset2 = dataset_ops.Dataset.range(100, 200)
    dataset3 = dataset_ops.Dataset.range(200, 300)

    dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(300)))

    # Same logical pipeline, constructed from datasets created in reverse
    # order; the snapshot fingerprint should match, so no second run.
    dataset4 = dataset_ops.Dataset.range(200, 300)
    dataset5 = dataset_ops.Dataset.range(100, 200)
    dataset6 = dataset_ops.Dataset.range(0, 100)

    dataset = dataset6.concatenate(dataset5).concatenate(dataset4)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(300)))

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testExpiredSnapshotRewrite(self):
    """A pending (unfinalized) snapshot is rewritten after it expires."""
    tmpdir = self.makeSnapshotDirectory()

    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(
        snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
    next1 = self.getNext(dataset1)

    # Don't finish reading dataset1, so it is never finalized
    for _ in range(500):
      self.evaluate(next1())
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    time.sleep(2)

    # Creating dataset2 after we run through dataset1 due to eager mode, where
    # the snapshot state is determined immediately upon dataset creation. We
    # only want to determine the snapshot state for dataset2 after the first
    # snapshot has expired.
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
    next2 = self.getNext(dataset2)

    for _ in range(500):
      self.evaluate(next2())
    # Same fingerprint, but a second run directory created by the rewrite.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)

  @combinations.generate(test_base.default_test_combinations())
  def testSpecifyShardSize(self):
    """Explicit `shard_size_bytes` controls how many shard files are cut."""
    tmpdir = self.makeSnapshotDirectory()

    dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
    # Each element is a 1024x1024 float tensor (~4MB uncompressed).
    dataset = dataset.map(
        lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
    dataset = dataset.repeat(10)
    dataset = dataset.apply(
        snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024))
    next_fn = self.getNext(dataset)

    for _ in range(10):
      self.evaluate(next_fn())

    # ~40MB of data with 10MB shards -> 4 shard files expected.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 4)

  @combinations.generate(test_base.default_test_combinations())
  def testAdditionalOperationsAfterReadBack(self):
    """Transformations applied after a snapshot read work correctly."""
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset2, expected)

    # Each record with its first two bytes ("Re") stripped off.
    expected_after = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(snapshot.snapshot(tmpdir))
    dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    self.assertDatasetProduces(dataset3, expected_after)
class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the `tf.data.Dataset.window` transformation."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=20,
              size=[10, 14, 17],
              shift=[7, 14],
              stride=[1, 2, 6],
              drop_remainder=[True, False]) + combinations.combine(
                  count=[0, 1],
                  size=10,
                  shift=4,
                  stride=1,
                  drop_remainder=[True, False])))
  def testWindowDataset(self, count, size, shift, stride, drop_remainder):
    """Tests a dataset that slides a window over its input elements."""
    # Three components of 7 elements each, with different ranks, to exercise
    # windowing over structured elements.
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def _flat_map_fn(x, y, z):
      # Batch each window's component datasets so a window becomes one element.
      return dataset_ops.Dataset.zip((x.batch(batch_size=size),
                                      y.batch(batch_size=size),
                                      z.batch(batch_size=size)))

    dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
        _map_fn).repeat(count).window(
            size=size,
            shift=shift,
            stride=stride,
            drop_remainder=drop_remainder).flat_map(_flat_map_fn)
    get_next = self.getNext(dataset)
    # Batched windows have an unknown leading (window) dimension.
    self.assertEqual(
        [[None] + list(c.shape[1:]) for c in components],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(dataset))])
    # Number of windows that contain a full `size` elements.
    num_full_batches = max(
        0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
    for i in range(num_full_batches):
      result = self.evaluate(get_next())
      for component, result_component in zip(components, result):
        for j in range(size):
          # Input repeats every 7 elements, hence the modulo; values were
          # squared by `_map_fn`.
          self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
                              result_component[j])
    if not drop_remainder:
      # Remaining windows are shorter than `size`.
      num_partial_batches = (count * 7) // shift + (
          (count * 7) % shift > 0) - num_full_batches
      for i in range(num_partial_batches):
        result = self.evaluate(get_next())
        for component, result_component in zip(components, result):
          remaining = (count * 7) - ((num_full_batches + i) * shift)
          num_elements = remaining // stride + ((remaining % stride) > 0)
          for j in range(num_elements):
            self.assertAllEqual(
                component[((num_full_batches + i) * shift + j *
                           stride) % 7]**2, result_component[j])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    # A second get_next after exhaustion must also raise.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(count=20, size=0, shift=3, stride=1) +
          combinations.combine(count=20, size=3, shift=0, stride=1) +
          combinations.combine(count=20, size=3, shift=3, stride=0)))
  def testWindowDatasetInvalid(self, count, size, shift, stride):
    """Zero-valued size/shift/stride arguments must be rejected."""
    with self.assertRaises(errors.InvalidArgumentError):
      ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
          size=size, shift=shift,
          stride=stride).flat_map(lambda x: x.batch(batch_size=size))
      self.evaluate(ds._variant_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowDifferentNestedStructures(self):
    """Windowing works over tuple- and dict-structured elements."""
    ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2)
    self.getNext(ds)
    ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2)
    self.getNext(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparse(self):
    """Windowing and batching of sparse tensors."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=5, shift=3,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))

    num_batches = (10 - 5) // 3 + 1
    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
            dense_shape=[5, 1]) for i in range(num_batches)
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparseWithDifferentDenseShapes(self):
    """Batched windows of sparse tensors with varying dense shapes."""

    def _sparse(i):
      # Element i is a sparse vector of length i with i stored i times.
      return sparse_tensor.SparseTensorValue(
          indices=array_ops.expand_dims(
              math_ops.range(i, dtype=dtypes.int64), 1),
          values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i),
          dense_shape=[i])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=5, shift=3,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))

    expected_output = []
    num_batches = (10 - 5) // 3 + 1
    for i in range(num_batches):
      expected_indices = []
      expected_values = []
      for j in range(5):
        for k in range(i * 3 + j):
          expected_indices.append([j, k])
          expected_values.append(i * 3 + j)
      expected_output.append(
          sparse_tensor.SparseTensorValue(
              indices=expected_indices,
              values=expected_values,
              # Batch dense shape is padded to the largest element in the
              # window.
              dense_shape=[5, i * 3 + 5 - 1]))
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedWindowSparse(self):
    """Two levels of window+batch applied to sparse tensors."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=4, shift=2,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
            size=3, shift=1,
            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=3))

    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
                     [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
                     [2, 2, 0], [2, 3, 0]],
            values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
            dense_shape=[3, 4, 1]),
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
                     [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
                     [2, 2, 0], [2, 3, 0]],
            values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
            dense_shape=[3, 4, 1])
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowShapeError(self):
    """Batching a window with mismatched element shapes raises."""

    def generator():
      yield [1.0, 2.0, 3.0]
      yield [4.0, 5.0, 6.0]
      yield [7.0, 8.0, 9.0, 10.0]

    dataset = dataset_ops.Dataset.from_generator(
        generator, dtypes.float32, output_shapes=[None]).window(
            size=3, shift=1).flat_map(lambda x: x.batch(batch_size=3))
    self.assertDatasetProduces(
        dataset,
        expected_error=(
            errors.InvalidArgumentError,
            r"Cannot batch tensors with different shapes in component 0. "
            r"First element had shape \[3\] and element 2 had shape \[4\]."))

  @combinations.generate(test_base.default_test_combinations())
  def testWindowIgnoreErrors(self):
    """NaN elements dropped by check_numerics; stride skips over them."""
    input_values = np.float32([1., np.nan, 2., np.nan, 3.])
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
        lambda x: array_ops.check_numerics(x, "message")).window(
            size=2, shift=2, stride=2,
            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
    self.assertDatasetProduces(
        dataset, expected_output=[np.float32([1., 2.]),
                                  np.float32([2., 3.])])

  @combinations.generate(test_base.default_test_combinations())
  def testNestedOutput(self):
    """Windows of zipped datasets yield nested datasets (eager only)."""
    if not context.executing_eagerly():
      self.skipTest("self.evaluate() does not work with a dataset")
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset_ops.Dataset.zip((dataset, dataset)).window(10)
    for i, nested_dataset in enumerate(dataset):
      x, y = nested_dataset
      self.assertDatasetProduces(x, range(i*10, (i+1)*10))
      self.assertDatasetProduces(y, range(i*10, (i+1)*10))

  @combinations.generate(test_base.default_test_combinations())
  def testDropRemainderOutput(self):
    """`drop_remainder=True` discards the final short window."""
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.window(30, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(30))
    dataset = dataset.batch(4)

    # 100 elements with window size 30 -> 3 full windows, remainder dropped.
    self.assertDatasetProduces(
        dataset,
        expected_output=[[[y + 30 * x for y in range(30)] for x in range(3)]])

  @combinations.generate(test_base.default_test_combinations())
  def testName(self):
    """The `name` argument is accepted and does not affect output."""
    dataset = dataset_ops.Dataset.from_tensors(42).window(
        1, name="window").flat_map(lambda x: x)
    self.assertDatasetProduces(dataset, [42])
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """Fault-tolerance tests for the tf.data service (dispatcher/workers)."""

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherStop(self):
    """Workers keep serving a started iteration after the dispatcher dies."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    results.append(next(iterator).numpy())
    cluster.stop_dispatcher()
    # After the dispatcher dies, the worker should continue providing the rest
    # of the dataset's elements.
    for _ in range(num_elements - 1):
      results.append(next(iterator).numpy())
    self.assertEqual(results, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBeforeReading(self):
    """A dispatcher restart before the first read is transparent."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    cluster.restart_dispatcher()

    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringReading(self):
    """A mid-iteration dispatcher restart loses no elements."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())

    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringDistributedEpoch(self):
    """Same as above but under distributed_epoch processing mode."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, processing_mode="distributed_epoch")
    iterator = iter(ds)
    results = []
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())
    cluster.restart_dispatcher()
    for elem in iterator:
      results.append(elem.numpy())

    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartDuringDistributedEpochRepeat(self):
    """Repeated dispatcher restarts during a repeated distributed epoch."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    repetitions = 5
    # Element counts at which the dispatcher is restarted.
    breakpoints = [50, 250, 450, 500]
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.repeat(repetitions)
    ds = self.make_distributed_dataset(ds, cluster,
                                       processing_mode="distributed_epoch")

    iterator = iter(ds)
    results = []
    # NOTE(review): `breakpoint` shadows the `breakpoint` builtin; harmless
    # here but worth renaming in a follow-up.
    for breakpoint in breakpoints:
      for _ in range(len(results), breakpoint):
        results.append(next(iterator).numpy())
      cluster.restart_dispatcher()

    self.assertCountEqual(repetitions * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherRestartBetweenIterations(self):
    """A dataset can be re-iterated across a dispatcher restart."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(100, cluster)
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherManyRestarts(self):
    """Datasets registered between successive restarts all remain usable."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements_start = 10
    num_elements_end = 15
    datasets = []
    for num_elements in range(num_elements_start, num_elements_end):
      datasets.append(
          self.make_distributed_range_dataset(num_elements, cluster))
      cluster.restart_dispatcher()
    for ds, num_elements in zip(
        datasets, range(num_elements_start, num_elements_end)):
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndWorkerRestart(self):
    """Restarting both dispatcher and worker restarts the iteration."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)

    cluster.restart_dispatcher()
    cluster.workers[0].restart()
    self.assertDatasetProduces(ds, list(range(num_elements)))
    cluster.restart_dispatcher()
    cluster.workers[0].restart()
    self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(workers_to_add=[1, 3, 10])))
  def testRoundRobinAddWorkers(self, workers_to_add):
    """New workers joining mid-job participate in round-robin reads."""
    starting_workers = 3
    cluster = data_service_test_base.TestCluster(
        num_workers=starting_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 7
    ds = self.make_round_robin_dataset(cluster, num_consumers)
    get_next = self.getNext(ds, requires_initialization=True)
    results = []

    # A zero element marks the start of a worker's contribution; count them to
    # detect when every worker has joined.
    zeros_seen = 0
    for _ in range(25):
      results.append(self.evaluate(get_next()))
      if results[-1] == 0:
        zeros_seen += 1
    for _ in range(workers_to_add):
      cluster.add_worker()
    # Read until all new workers have joined.
    while zeros_seen < starting_workers + workers_to_add:
      results.append(self.evaluate(get_next()))
      if results[-1] == 0:
        zeros_seen += 1
    # Read some more.
    for _ in range(25):
      results.append(self.evaluate(get_next()))

    self.checkRoundRobinGroups(results, num_consumers)

  @combinations.generate(test_base.eager_only_combinations())
  def testRoundRobinRestartWorker(self):
    """Round-robin reads survive a worker stop/restart."""
    num_workers = 3
    # Set a shutdown quiet period to prevent workers from shutting down partway
    # through a round.
    cluster = data_service_test_base.TestCluster(
        num_workers, worker_shutdown_quiet_period_ms=2000)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 5
    ds = self.make_round_robin_dataset(cluster, num_consumers)
    get_next = self.getNext(ds, requires_initialization=True)
    results = []

    self.read(get_next, results, 20)
    cluster.workers[1].stop()
    # Check that we can continue to read even with a worker stopped.
    self.read(get_next, results, 20)
    cluster.workers[1].restart()
    # Read until we get results from the restarted worker, then read some more.
    while results[-1] != 0:
      results.append(self.evaluate(get_next()))
    self.read(get_next, results, 20)

    self.checkRoundRobinGroups(results, num_consumers)

  @combinations.generate(test_base.eager_only_combinations())
  def testRoundRobinMultiStartStop(self):
    """Repeated worker stops/restarts plus cluster growth keep rounds valid."""
    num_workers = 3
    # Set a shutdown quiet period to prevent workers from shutting down partway
    # through a round.
    cluster = data_service_test_base.TestCluster(
        num_workers, worker_shutdown_quiet_period_ms=2000)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 5
    ds = self.make_round_robin_dataset(cluster, num_consumers)
    get_next = self.getNext(ds, requires_initialization=True)
    results = []

    self.read(get_next, results, 20)
    for i in range(num_workers):
      cluster.workers[i].stop()
      self.read(get_next, results, 20)
      cluster.workers[i].restart()
      self.read(get_next, results, 20)
    cluster.add_worker()
    cluster.restart_dispatcher()
    for i in range(num_workers):
      cluster.workers[i].stop()
      self.read(get_next, results, 20)

    self.checkRoundRobinGroups(results, num_consumers)

  @combinations.generate(test_base.eager_only_combinations())
  def testDispatcherAndMultiWorkerRestart(self):
    """Restarting everything mid-read restarts each worker's contribution."""
    num_workers = 2
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []

    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.workers[worker_index].restart()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)
    cluster.restart_dispatcher()
    for worker_index in range(num_workers):
      cluster.workers[worker_index].restart()
    for elem in iterator:
      results.append(elem.numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testStartServersLate(self):
    # Test that the data service client performs retries instead of failing when
    # the dataset is created before the master and worker are started.
    try:
      import portpicker  # pylint: disable=g-import-not-at-top
      dispatcher_port = portpicker.pick_unused_port()
    # NOTE(review): bare `except` intentionally covers any portpicker flake;
    # `self.skipTest` raises SkipTest itself, so the `raise` is never reached.
    except:
      raise self.skipTest("Flakes in portpicker library do not represent "
                          "TensorFlow errors.")
    cluster = data_service_test_base.TestCluster(
        num_workers=1, dispatcher_port=dispatcher_port, start=False)

    def start_servers():
      # Give the client a head start so it must retry.
      time.sleep(0.5)
      cluster.start_dispatcher()
      cluster.start_workers()

    start_servers_thread = threading.Thread(target=start_servers, daemon=True)
    start_servers_thread.start()

    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)
    start_servers_thread.join()

  @combinations.generate(test_base.eager_only_combinations())
  def testAddWorkerMidJob(self):
    """A worker added mid-job serves the dataset from its beginning."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
      results.append(next(iterator).numpy())

    cluster.add_worker()
    # Wait for the new worker to register with the dispatcher.
    while cluster.num_registered_workers() < 2:
      time.sleep(10 / 1000)  # 10ms

    for elem in iterator:
      results.append(elem.numpy())

    self.assertCountEqual(2 * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(use_same_port=[True, False]),
                         data_service_test_base.all_cluster_configurations()))
  def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
    """A restarted worker begins serving its dataset from the start."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 100
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    iterator = iter(ds)
    # Read halfway through the dataset.
    midpoint = num_elements // 2
    for i in range(midpoint):
      self.assertEqual(i, next(iterator).numpy())

    # Stop the original worker and start a new one.
    cluster.workers[0].restart(use_same_port=use_same_port)

    # There may have been some elements prefetched from the first worker
    # before it was stopped.
    while True:
      val = next(iterator).numpy()
      if val == 0:
        break

    # The dataset starts over now that we read from the new worker.
    # TODO(b/157086991): Iterate until end of sequence when we support
    # detecting lost workers.
    for i in range(1, num_elements // 2):
      val = next(iterator).numpy()
      self.assertEqual(i, val)

  @combinations.generate(test_base.eager_only_combinations())
  def testChangeProcessingModeAfterRestart(self):
    """Reusing a job name with a different processing mode is an error."""
    self.skipTest("b/170910141")
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    range_dataset = dataset_ops.Dataset.range(num_elements)
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service=cluster.dispatcher_address(),
            job_name="test"))
    iterator = iter(ds)
    for i in range(num_elements // 2):
      self.assertEqual(i, next(iterator).numpy())
    cluster.restart_dispatcher()
    ds = range_dataset.apply(
        data_service_ops.distribute(
            processing_mode="distributed_epoch",
            service=cluster.dispatcher_address(),
            job_name="test"))
    with self.assertRaisesOpError(
        "already an existing job with that name "
        "using processing mode <parallel_epochs>"):
      next(iter(ds)).numpy()

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
  def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
    """A worker registered after dataset creation fetches a large graph."""
    cluster = data_service_test_base.TestCluster(
        num_workers=0, work_dir=work_dir, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    it = iter(ds)
    cluster.add_worker()
    self.assertAllEqual(next(it), tensor)
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.experimental.Optional` and its variant-tensor plumbing."""

  @combinations.generate(test_base.default_test_combinations())
  def testFromValue(self):
    """An optional built from a value reports has_value and yields it."""
    opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual(37.0, self.evaluate(opt.get_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromStructuredValue(self):
    """Optionals preserve nested (dict/tuple) structure."""
    opt = optional_ops.Optional.from_value({
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    })
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual({
        "a": 37.0,
        "b": ([b"Foo"], b"Bar")
    }, self.evaluate(opt.get_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testFromSparseTensor(self):
    """Optionals round-trip sparse tensors (indices, values, dense_shape)."""
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    st_1 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))
    opt = optional_ops.Optional.from_value((st_0, st_1))
    self.assertTrue(self.evaluate(opt.has_value()))
    val_0, val_1 = opt.get_value()
    for expected, actual in [(st_0, val_0), (st_1, val_1)]:
      self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
      self.assertAllEqual(expected.values, self.evaluate(actual.values))
      self.assertAllEqual(expected.dense_shape,
                          self.evaluate(actual.dense_shape))

  @combinations.generate(test_base.default_test_combinations())
  def testFromNone(self):
    """An empty optional has no value; getting it raises."""
    value_structure = tensor_spec.TensorSpec([], dtypes.float32)
    opt = optional_ops.Optional.empty(value_structure)
    self.assertTrue(opt.element_spec.is_compatible_with(value_structure))
    # Compatibility is exact on both shape and dtype.
    self.assertFalse(
        opt.element_spec.is_compatible_with(
            tensor_spec.TensorSpec([1], dtypes.float32)))
    self.assertFalse(
        opt.element_spec.is_compatible_with(
            tensor_spec.TensorSpec([], dtypes.int32)))
    self.assertFalse(self.evaluate(opt.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  @combinations.generate(test_base.default_test_combinations())
  def testAddN(self):
    """add_n on optional variant tensors adds values element-wise."""
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        # With value
        opt1 = optional_ops.Optional.from_value((1.0, 2.0))
        opt2 = optional_ops.Optional.from_value((3.0, 4.0))

        add_tensor = math_ops.add_n(
            [opt1._variant_tensor, opt2._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor, opt1.element_spec)
        self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))

        # Without value
        opt_none1 = optional_ops.Optional.empty(opt1.element_spec)
        opt_none2 = optional_ops.Optional.empty(opt2.element_spec)
        add_tensor = math_ops.add_n(
            [opt_none1._variant_tensor, opt_none2._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor,
                                             opt_none1.element_spec)
        self.assertFalse(self.evaluate(add_opt.has_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedAddN(self):
    """add_n recurses into optionals nested inside optionals."""
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        opt1 = optional_ops.Optional.from_value([1, 2.0])
        opt2 = optional_ops.Optional.from_value([3, 4.0])
        opt3 = optional_ops.Optional.from_value(
            (5.0, opt1._variant_tensor))
        opt4 = optional_ops.Optional.from_value(
            (6.0, opt2._variant_tensor))

        add_tensor = math_ops.add_n(
            [opt3._variant_tensor, opt4._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor, opt3.element_spec)
        self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)

        # The inner component is itself an optional variant; unwrap and check.
        inner_add_opt = optional_ops._OptionalImpl(
            add_opt.get_value()[1], opt1.element_spec)
        self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])

  @combinations.generate(test_base.default_test_combinations())
  def testZerosLike(self):
    """zeros_like of an optional variant zeros the value, keeps emptiness."""
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        # With value
        opt = optional_ops.Optional.from_value((1.0, 2.0))
        zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt.element_spec)
        self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
                            (0.0, 0.0))

        # Without value
        opt_none = optional_ops.Optional.empty(opt.element_spec)
        zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt_none.element_spec)
        self.assertFalse(self.evaluate(zeros_opt.has_value()))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedZerosLike(self):
    """zeros_like recurses into a nested optional."""
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    for device in devices:
      with ops.device(device):
        opt1 = optional_ops.Optional.from_value(1.0)
        opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)

        zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt2.element_spec)
        inner_zeros_opt = optional_ops._OptionalImpl(
            zeros_opt.get_value(), opt1.element_spec)
        self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)

  @combinations.generate(test_base.default_test_combinations())
  def testCopyToGPU(self):
    """Optionals created on CPU can be copied to and used on GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.empty(
          tensor_spec.TensorSpec([], dtypes.float32))
    with ops.device("/gpu:0"):
      # `identity` on the variant tensor forces a device copy.
      gpu_optional_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.element_spec)
      gpu_optional_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.element_spec)

      gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
      gpu_optional_with_value_values = gpu_optional_with_value.get_value()

      gpu_optional_none_has_value = gpu_optional_none.has_value()
    self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  @combinations.generate(test_base.default_test_combinations())
  def testNestedCopyToGPU(self):
    """Nested optionals survive a CPU-to-GPU copy."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.empty(
          tensor_spec.TensorSpec([], dtypes.float32))
      nested_optional = optional_ops.Optional.from_value(
          (optional_with_value._variant_tensor,
           optional_none._variant_tensor, 1.0))

    with ops.device("/gpu:0"):
      gpu_nested_optional = optional_ops._OptionalImpl(
          array_ops.identity(nested_optional._variant_tensor),
          nested_optional.element_spec)

      gpu_nested_optional_has_value = gpu_nested_optional.has_value()
      gpu_nested_optional_values = gpu_nested_optional.get_value()

    self.assertTrue(self.evaluate(gpu_nested_optional_has_value))

    inner_with_value = optional_ops._OptionalImpl(
        gpu_nested_optional_values[0], optional_with_value.element_spec)
    inner_none = optional_ops._OptionalImpl(gpu_nested_optional_values[1],
                                            optional_none.element_spec)

    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(inner_with_value.get_value()))
    self.assertFalse(self.evaluate(inner_none.has_value()))
    self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _optional_spec_test_combinations()))
  def testOptionalSpec(self, tf_value_fn, expected_value_structure):
    """OptionalSpec compatibility, flat-tensor form, and round-tripping."""
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)

    self.assertTrue(
        structure.are_compatible(opt.element_spec, expected_value_structure))

    opt_structure = structure.type_spec_from_value(opt)
    self.assertIsInstance(opt_structure, optional_ops.OptionalSpec)
    self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
    self.assertTrue(
        structure.are_compatible(opt_structure._element_spec,
                                 expected_value_structure))
    # An optional flattens to a single scalar variant tensor.
    self.assertEqual([dtypes.variant],
                     structure.get_flat_tensor_types(opt_structure))
    self.assertEqual([tensor_shape.TensorShape([])],
                     structure.get_flat_tensor_shapes(opt_structure))

    # All OptionalSpec objects are not compatible with a non-optional
    # value.
    non_optional_structure = structure.type_spec_from_value(
        constant_op.constant(42.0))
    self.assertFalse(
        opt_structure.is_compatible_with(non_optional_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_opt = opt_structure._from_tensor_list(
        opt_structure._to_tensor_list(opt))
    if isinstance(tf_value, optional_ops.Optional):
      self.assertValuesEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self.assertValuesEqual(
          self.evaluate(tf_value),
          self.evaluate(round_trip_opt.get_value()))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _get_next_as_optional_test_combinations()))
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                    gpu_compatible):
    """get_next_as_optional behavior in both eager and graph mode."""
    if not gpu_compatible and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

    if context.executing_eagerly():
      iterator = dataset_ops.make_one_shot_iterator(ds)
      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      for _ in range(3):
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertIsInstance(next_elem, optional_ops.Optional)
        self.assertTrue(
            structure.are_compatible(
                next_elem.element_spec,
                structure.type_spec_from_value(tf_value_fn())))
        self.assertTrue(next_elem.has_value())
        self.assertValuesEqual(np_value, next_elem.get_value())
      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertFalse(self.evaluate(next_elem.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(next_elem.get_value())
    else:
      iterator = dataset_ops.make_initializable_iterator(ds)
      next_elem = iterator_ops.get_next_as_optional(iterator)
      self.assertIsInstance(next_elem, optional_ops.Optional)
      self.assertTrue(
          structure.are_compatible(
              next_elem.element_spec,
              structure.type_spec_from_value(tf_value_fn())))
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError. This is only relevant in graph mode.
      elem_has_value_t = next_elem.has_value()
      elem_value_t = next_elem.get_value()
      with self.assertRaises(errors.FailedPreconditionError):
        self.evaluate(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        self.evaluate(elem_value_t)

      # Now we initialize the iterator.
      self.evaluate(iterator.initializer)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      for _ in range(3):
        elem_has_value, elem_value = self.evaluate(
            [elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self.assertValuesEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
for _ in range(2): self.assertFalse(self.evaluate(elem_has_value_t)) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(elem_value_t) @combinations.generate(test_base.default_test_combinations()) def testFunctionBoundaries(self): @def_function.function def get_optional(): x = constant_op.constant(1.0) opt = optional_ops.Optional.from_value(x) # TODO(skyewm): support returning Optionals from functions? return opt._variant_tensor # TODO(skyewm): support Optional arguments? @def_function.function def consume_optional(opt_tensor): value_structure = tensor_spec.TensorSpec([], dtypes.float32) opt = optional_ops._OptionalImpl(opt_tensor, value_structure) return opt.get_value() opt_tensor = get_optional() val = consume_optional(opt_tensor) self.assertEqual(self.evaluate(val), 1.0) @combinations.generate(test_base.default_test_combinations()) def testLimitedRetracing(self): trace_count = [0] @def_function.function def f(opt): trace_count[0] += 1 return opt.get_value() opt1 = optional_ops.Optional.from_value(constant_op.constant(37.0)) opt2 = optional_ops.Optional.from_value(constant_op.constant(42.0)) for _ in range(10): self.assertEqual(self.evaluate(f(opt1)), 37.0) self.assertEqual(self.evaluate(f(opt2)), 42.0) self.assertEqual(trace_count[0], 1)
class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for tf.data static graph optimizations and their toggles."""

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationStatefulFunction(self):
    """A stateful map function still runs with default optimizations off."""
    dataset = dataset_ops.Dataset.range(10).map(
        lambda _: random_ops.random_uniform([])).batch(10)
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    get_next = self.getNext(dataset)
    self.evaluate(get_next())

  # TODO(b/123354468)
  @combinations.generate(test_base.graph_only_combinations())
  def testOptimizationLargeInputFromTensor(self):
    """Optimization handles a large (>2GB-ish) tensor fed via placeholder."""
    input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
    dataset = dataset_ops.Dataset.from_tensors(input_t)
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
      self.evaluate(get_next)

  # TODO(b/123354468)
  @combinations.generate(test_base.graph_only_combinations())
  def testOptimizationLargeInputFromTensorSlices(self):
    """Same as above but slicing along the leading dimension."""
    input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
    dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
      self.evaluate(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationNestedDataset(self):
    """Noop elimination applies inside a flat_map'd nested dataset."""

    def flat_map_fn(_):
      dataset = dataset_ops.Dataset.from_tensors(0)
      # After noop elimination removes skip(0), cache() is next.
      dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
      dataset = dataset.skip(0)  # Should be removed by noop elimination
      dataset = dataset.cache()
      return dataset

    dataset = dataset_ops.Dataset.range(1)
    dataset = dataset.flat_map(flat_map_fn)
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.noop_elimination = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, expected_output=[0])

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationNestedDatasetWithModifiedRetval(self):
    """Map+batch fusion rewrites the return value of a flat_map function."""

    def flat_map_fn(_):
      dataset = dataset_ops.Dataset.from_tensors(0)
      dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
      # Should be fused by map and batch fusion
      dataset = dataset.map(lambda x: x)
      dataset = dataset.batch(1)
      return dataset

    dataset = dataset_ops.Dataset.range(1)
    dataset = dataset.flat_map(flat_map_fn)
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.map_and_batch_fusion = True
    dataset = dataset.with_options(options)
    # Batched output: a single element [0].
    self.assertDatasetProduces(dataset, expected_output=[[0]])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(autotune=[True, False, None]),
          combinations.combine(map_parallelization=[True, False, None])))
  def testOptimizationMapParallelization(self, autotune, map_parallelization):
    """Map is parallelized unless autotune or the rewrite is disabled.

    None means "leave the option unset" and should behave like the default
    (enabled), so only an explicit False on either toggle keeps plain Map.
    """
    dataset = dataset_ops.Dataset.range(5)
    if autotune is not False and map_parallelization is not False:  # pylint: disable=g-bool-id-comparison
      dataset = dataset.apply(testing.assert_next(["ParallelMap"]))
    else:
      dataset = dataset.apply(testing.assert_next(["Map"]))
    dataset = dataset.map(lambda x: x + 1)

    options = options_lib.Options()
    if autotune is not None:
      options.autotune.enabled = autotune
    if map_parallelization is not None:
      options.experimental_optimization.map_parallelization = (
          map_parallelization)
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, expected_output=list(range(1, 6)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(existing_prefetch=[True, False]),
          combinations.combine(autotune=[True, False]),
          combinations.combine(set_env=[True, False])))
  def testOptimizationInjectPrefetch(self, existing_prefetch, autotune,
                                     set_env):
    """Prefetch is injected only with autotune + opt-in env and no prefetch."""
    if set_env:
      # The rewrite is gated behind an experiment opt-in plus a job name.
      os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "inject_prefetch"
      os.environ["TF_JOB_NAME"] = "test_job"

    dataset = dataset_ops.Dataset.range(5)
    dataset = dataset.map(
        lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
    if existing_prefetch:
      dataset = dataset.prefetch(1)
    if autotune and set_env and not existing_prefetch:
      dataset = dataset.apply(testing.assert_next(["Prefetch", "Root"]))
    else:
      dataset = dataset.apply(testing.assert_next(["Root"]))

    options = options_lib.Options()
    options.autotune.enabled = autotune
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, expected_output=list(range(1, 6)))

    if set_env:
      # Clean up so the env vars do not leak into other tests.
      del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
      del os.environ["TF_JOB_NAME"]

  # Reference variables are not supported in eager mode.
  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _captured_refvar_test_combinations()))
  def testOptimizationWithCapturedRefVar(self, dataset_fn):
    """Tests that default optimizations are disabled with ref variables."""
    variable = variable_scope.get_variable(
        "v", initializer=0, use_resource=False)
    assign_op = variable.assign_add(1)
    unoptimized_dataset = dataset_fn(variable)

    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.noop_elimination = True
    options.experimental_optimization.map_and_batch_fusion = True
    optimized_dataset = unoptimized_dataset.with_options(options)
    optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)

    # Check that outputs are the same in the optimized and unoptimized cases,
    # when the variable value is changing.
    unoptimized_it = dataset_ops.make_initializable_iterator(
        unoptimized_dataset)
    with ops.control_dependencies([assign_op]):
      unoptimized_output = unoptimized_it.get_next()
      optimized_output = optimized_it.get_next()

    self.evaluate(variable.initializer)
    self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
    while True:
      try:
        unoptimized, optimized = self.evaluate(
            (unoptimized_output, optimized_output))
        self.assertEqual(unoptimized, optimized)
      except errors.OutOfRangeError:
        break