class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsDefault(self):
    ds = dataset_ops.Dataset.range(0)
    self.assertEqual(dataset_ops.Options(), ds.options())

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsOnce(self):
    options = dataset_ops.Options()
    ds = dataset_ops.Dataset.range(0).with_options(options).cache()
    self.assertEqual(options, ds.options())

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsTwiceSame(self):
    options = dataset_ops.Options()
    options.experimental_optimization.autotune = True
    ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
        options)
    self.assertEqual(options, ds.options())

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsTwiceDifferentOptions(self):
    options1 = dataset_ops.Options()
    options1.experimental_optimization.autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_deterministic = False
    ds = dataset_ops.Dataset.range(0)
    ds = ds.with_options(options1)
    ds = ds.with_options(options2)
    self.assertTrue(ds.options().experimental_optimization.autotune)
    # Explicitly check that the flag is False, since assertFalse allows None.
    self.assertIs(ds.options().experimental_deterministic, False)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsTwiceSameOption(self):
    if sys.version_info >= (3, 8) and platform.system() == "Windows":
      # TODO(b/165013260): Fix this
      self.skipTest("Test is currently broken on Windows with Python 3.8")
    options1 = dataset_ops.Options()
    options1.experimental_optimization.autotune = False
    options2 = dataset_ops.Options()
    options2.experimental_optimization.autotune = True
    ds = dataset_ops.Dataset.range(0)
    ds = ds.with_options(options1)
    ds = ds.with_options(options2)
    self.assertTrue(ds.options().experimental_optimization.autotune)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsMergeOptionsFromMultipleInputs(self):
    options1 = dataset_ops.Options()
    options1.experimental_optimization.autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_deterministic = True
    ds1 = dataset_ops.Dataset.range(0).with_options(options1)
    ds2 = dataset_ops.Dataset.range(0).with_options(options2)
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertTrue(ds.options().experimental_optimization.autotune)
    self.assertTrue(ds.options().experimental_deterministic)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsHaveDefaults(self):
    options1 = dataset_ops.Options()
    options2 = dataset_ops.Options()
    self.assertIsNot(options1.experimental_optimization,
                     options2.experimental_optimization)
    self.assertIsNot(options1.experimental_stats,
                     options2.experimental_stats)
    self.assertIsNot(options1.experimental_threading,
                     options2.experimental_threading)
    self.assertEqual(options1.experimental_optimization,
                     optimization_options.OptimizationOptions())
    self.assertEqual(options1.experimental_stats, stats_options.StatsOptions())
    self.assertEqual(options1.experimental_threading,
                     threading_options.ThreadingOptions())

  @combinations.generate(test_base.default_test_combinations())
  def testMutatingOptionsRaiseValueError(self):
    ds = dataset_ops.Dataset.range(0)
    options1 = dataset_ops.Options()
    options1.experimental_slack = True
    options2 = dataset_ops.Options()
    options2.experimental_optimization.autotune = True
    ds = ds.with_options(options1)
    ds = ds.map(lambda x: 2 * x)
    ds = ds.with_options(options2)
    with self.assertRaises(ValueError):
      dataset_options = ds.options()
      dataset_options.experimental_deterministic = True

  @combinations.generate(test_base.eager_only_combinations())
  def testNestedDataset(self):
    ds = dataset_ops.Dataset.from_tensors(0)
    result = ds
    for _ in range(999):
      result = result.concatenate(ds)
    self.assertDatasetProduces(result, [0] * 1000)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsProtoRoundTrip(self):
    options = dataset_ops.Options()
    options.experimental_deterministic = True
    options.experimental_external_state_policy = (
        distribute_options.ExternalStatePolicy.FAIL)
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.DATA)
    options.experimental_distribute.num_devices = 1000
    options.experimental_optimization.apply_default_optimizations = True
    options.experimental_optimization.autotune = True
    options.experimental_optimization.autotune_buffers = True
    options.experimental_optimization.autotune_cpu_budget = 10
    options.experimental_optimization.autotune_ram_budget = 20
    options.experimental_optimization.filter_fusion = True
    options.experimental_optimization.filter_with_random_uniform_fusion = True
    options.experimental_optimization.hoist_random_uniform = True
    options.experimental_optimization.map_and_batch_fusion = True
    options.experimental_optimization.map_and_filter_fusion = True
    options.experimental_optimization.map_fusion = True
    options.experimental_optimization.map_parallelization = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.map_vectorization.use_choose_fastest = (
        True)
    options.experimental_optimization.noop_elimination = True
    options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.reorder_data_discarding_ops = True
    options.experimental_optimization.shuffle_and_repeat_fusion = True
    options.experimental_slack = True
    options.experimental_threading.max_intra_op_parallelism = 30
    options.experimental_threading.private_threadpool_size = 40
    pb = options._to_proto()
    result = dataset_ops.Options()
    result._from_proto(pb)
    self.assertEqual(options, result)

  @combinations.generate(test_base.default_test_combinations())
  def testOptionsProtoDefaultValuesRoundTrip(self):
    options = dataset_ops.Options()
    pb = options._to_proto()
    result = dataset_ops.Options()
    result._from_proto(pb)
    self.assertEqual(options, result)

  @combinations.generate(test_base.default_test_combinations())
  def testProtoOptionsDefaultValuesRoundTrip(self):
    pb = dataset_options_pb2.Options()
    options = dataset_ops.Options()
    options._from_proto(pb)
    result = options._to_proto()
    expected_pb = dataset_options_pb2.Options()
    expected_pb.distribute_options.CopyFrom(
        dataset_options_pb2.DistributeOptions())
    expected_pb.optimization_options.CopyFrom(
        dataset_options_pb2.OptimizationOptions())
    expected_pb.optimization_options.map_vectorization.CopyFrom(
        dataset_options_pb2.MapVectorization())
    expected_pb.threading_options.CopyFrom(
        dataset_options_pb2.ThreadingOptions())
    self.assertProtoEquals(expected_pb, result)

class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) +
          combinations.combine(
              skip=[100],
              take=[1],
              error=[errors.InvalidArgumentError],
              error_msg=["Dataset was empty."]) +
          combinations.combine(
              skip=[0],
              take=[2],
              error=[errors.InvalidArgumentError],
              error_msg=["Dataset had more than one element."])))
  def testGetSingleElement(self, skip, take, error=None, error_msg=None):

    def make_sparse(x):
      x_1d = array_ops.reshape(x, [1])
      x_2d = array_ops.reshape(x, [1, 1])
      return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)

    dataset = dataset_ops.Dataset.range(100).skip(skip).map(
        lambda x: (x * x, make_sparse(x))).take(take)
    if error is None:
      dense_val, sparse_val = self.evaluate(
          get_single_element.get_single_element(dataset))
      self.assertEqual(skip * skip, dense_val)
      self.assertAllEqual([[skip]], sparse_val.indices)
      self.assertAllEqual([skip], sparse_val.values)
      self.assertAllEqual([skip], sparse_val.dense_shape)
    else:
      with self.assertRaisesRegexp(error, error_msg):
        self.evaluate(get_single_element.get_single_element(dataset))

  @combinations.generate(test_base.default_test_combinations())
  def testWindow(self):
    """Test that `get_single_element()` can consume a nested dataset."""

    def flat_map_func(ds):
      batched = ds.batch(2)
      element = get_single_element.get_single_element(batched)
      return dataset_ops.Dataset.from_tensors(element)

    dataset = dataset_ops.Dataset.range(10).window(2).flat_map(flat_map_func)
    self.assertDatasetProduces(
        dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])

  @combinations.generate(test_base.default_test_combinations())
  def testSideEffect(self):
    counter_var = variables.Variable(0)

    def increment_fn(x):
      counter_var.assign_add(1)
      return x

    def dataset_fn():
      return dataset_ops.Dataset.range(1).map(increment_fn)

    @function.defun
    def fn():
      _ = get_single_element.get_single_element(dataset_fn())
      return "hello"

    self.evaluate(counter_var.initializer)
    self.assertEqual(self.evaluate(fn()), b"hello")
    self.assertEqual(self.evaluate(counter_var), 1)

  @combinations.generate(test_base.default_test_combinations())
  def testAutomaticControlDependencies(self):
    counter_var = variables.Variable(1)

    def increment_fn(x):
      counter_var.assign(counter_var + 1)
      return x

    def multiply_fn(x):
      counter_var.assign(counter_var * 2)
      return x

    def dataset1_fn():
      return dataset_ops.Dataset.range(1).map(increment_fn)

    def dataset2_fn():
      return dataset_ops.Dataset.range(1).map(multiply_fn)

    @function.defun
    def fn():
      _ = get_single_element.get_single_element(dataset1_fn())
      _ = get_single_element.get_single_element(dataset2_fn())
      return "hello"

    self.evaluate(counter_var.initializer)
    self.assertEqual(self.evaluate(fn()), b"hello")
    self.assertEqual(self.evaluate(counter_var), 4)

class LocalTaskGarbageCollectTest(data_service_test_base.TestBase,
                                  parameterized.TestCase):
  """Tests garbage collecting unused local worker tasks.

  The user typically creates an iterator in each epoch. This should delete the
  previous iterator and release its resources.
  """

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochs(self, num_remote_workers):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_epochs, num_steps = 5, 5
    dataset = self._make_distributed_infinite_range_dataset(cluster)
    for _ in range(num_epochs):
      # For each iteration, the previous iterator is garbage collected.
      get_next = self.getNext(dataset)
      for i in range(num_steps):
        self.assertEqual(self.evaluate(get_next()), i)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochsSharedJob(self, num_remote_workers):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_epochs, num_steps = 5, 5
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    for _ in range(num_epochs):
      # For each iteration, the previous iterator is garbage collected.
      get_next = self.getNext(dataset)
      for i in range(num_steps):
        self.assertEqual(self.evaluate(get_next()), i)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              num_remote_workers=[0, 3],
              job_name=[None, "shared_job_name"])))
  def testRepeatDistributedDataset(self, num_remote_workers, job_name):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    dataset = self.make_distributed_range_dataset(
        10, cluster, job_name=job_name, target_workers="LOCAL")
    dataset = dataset.repeat(3)
    self.assertDatasetProduces(dataset, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testReadFromDeletedTask(self, num_remote_workers):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

    # Re-creating the dataset resets the iterator index, so the second iterator
    # reads from the same task as the first, which has been deleted.
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "which has been deleted."):
      _ = self.evaluate(get_next())

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testReadFromDeletedTask_GraphMode(self, num_remote_workers):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    with self.session() as sess:
      get_next = self.getNext(dataset)
      for i in range(num_steps):
        self.assertEqual(sess.run(get_next()), i)

    # Re-creating the dataset resets the iterator index, so the second iterator
    # reads from the same task as the first, which has been deleted.
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "which has been deleted."):
      with self.session() as sess:
        get_next = self.getNext(dataset)
        sess.run(get_next())

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochs_WorkerRestart(self, num_remote_workers):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

    # Verifies the worker re-creates the task after the iterator is deleted and
    # the worker restarts.
    del get_next
    cluster.restart_local_workers()
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(num_remote_workers=[0, 3])))
  def testMultipleEpochs_DispatcherRestart(self, num_remote_workers):
    num_local_workers = 1
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_steps = 10
    dataset = self._make_distributed_infinite_range_dataset(
        cluster, job_name="shared_job_name")
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

    # Verifies the worker re-creates the task after the iterator is deleted and
    # the dispatcher restarts.
    del get_next
    cluster.restart_dispatcher()
    get_next = self.getNext(dataset)
    for i in range(num_steps):
      self.assertEqual(self.evaluate(get_next()), i)

  def _make_distributed_infinite_range_dataset(self, cluster, job_name=None):
    dataset = dataset_ops.Dataset.range(1000000).repeat()
    return self.make_distributed_dataset(
        dataset,
        cluster=cluster,
        job_name=job_name,
        processing_mode=ShardingPolicy.OFF,
        target_workers="LOCAL")

class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):

  # TODO(b/121264236)
  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testPrefetchWithSlackOption(self):
    """Determines slack_period based on the number of devices attached to the iterator."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.prefetch(1)
    options = dataset_ops.Options()
    options.experimental_slack = True
    dataset = dataset.with_options(options)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 10, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testPrefetchWithSlackOptionWithoutIterator(self):
    """Defaults to a slack period of 1 without a multi-device iterator."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.prefetch(1)
    options = dataset_ops.Options()
    options.experimental_slack = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, range(10))

  @combinations.generate(test_base.default_test_combinations())
  def testWithPassthroughDataset(self):
    """Should still work with a passthrough dataset after prefetch()."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.prefetch(1)
    dataset = dataset.map(lambda x: x + 1)
    options = dataset_ops.Options()
    options.experimental_slack = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, range(1, 11))

  @combinations.generate(test_base.default_test_combinations())
  def testNoErrorWithoutPrefetch(self):
    """The rewrite should not fail if there is no prefetch() in the pipeline."""
    dataset = dataset_ops.Dataset.range(10)
    options = dataset_ops.Options()
    options.experimental_slack = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, range(10))

  @combinations.generate(test_base.default_test_combinations())
  def testNoErrorWithInvalidDataset(self):
    """With a nested dataset op after prefetch, the slack rewrite cannot be applied, but it should not raise an error."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.prefetch(1)
    dataset = dataset.flat_map(dataset_ops.Dataset.from_tensors)
    options = dataset_ops.Options()
    options.experimental_slack = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, range(10))

class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testBasic(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)

    self.assertEqual([[8] if drop_remainder else [None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCanHandleUnknownRank(self):
    dataset = dataset_ops.Dataset.from_tensors("xxx")
    # decode_image results in a tensor of completely unknown shape (i.e.
    # unknown rank).
    dataset = dataset.map(image_ops.decode_image)
    self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(dataset))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    # Note that we are just testing the dataset shapes, not the actual output.
    self.assertEqual([tensor_shape.TensorShape(None)],
                     _flat_shapes(rebatched_dataset))

  @combinations.generate(test_base.default_test_combinations())
  def testCanHandleUnknownDims(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.batch(10, drop_remainder=False)
    dataset = dataset.batch(10, drop_remainder=False)
    self.assertEqual([[None, None]],
                     [ts.as_list() for ts in _flat_shapes(dataset)])
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    # Note that we are just testing the dataset shapes, not the actual output.
    self.assertEqual([[None, None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

  @combinations.generate(test_base.default_test_combinations())
  def testScalarInputError(self):
    dataset = dataset_ops.Dataset.range(1024)
    distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
    with self.assertRaisesRegex(ValueError, ("You can fix the issue "
                                             "by adding the `batch`")):
      distribute._RebatchDataset(dataset, num_replicas=4)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
    self.assertEqual([[None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    expected_output = []
    i = 0
    for _ in range(32):  # number of steps
      # first four minibatches have seven elements
      for _ in range(4):
        expected_output.append([k for k in range(i, i + 7)])
        i += 7
      # last minibatch has four elements
      expected_output.append([k for k in range(i, i + 4)])
      i += 4
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testBatchSizeNotDivisibleByNumReplicas2(self):
    dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
    # This will rebatch into sub-batches of size 4, since
    # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
    # data.
    expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)]
    expected_output.extend([[]])  # Last replica gets an empty batch
    expected_output.extend(
        [[k for k in range(i, i + 4)] for i in range(16, 32, 4)])
    expected_output.extend([[]])  # Last replica gets an empty batch
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testTupleOutput(self):
    dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    expected_output = [([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                        [k for k in range(i, i + 8)])
                       for i in range(0, 1024, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedDictionaryOutput(self):
    dataset = dataset_ops.Dataset.range(1024).map(
        lambda x: {"a": x, "b": {"c": x}}).batch(32)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    expected_output = [{"a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                        "b": {"c": [k for k in range(i, i + 8)]}}
                       for i in range(0, 1024, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testFinalPartialBatch(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(1032).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[8] if drop_remainder else [None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    # if drop_remainder, the final partial batch is dropped, even though it
    # makes up a complete minibatch.
    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    if not drop_remainder:
      # The last partial batch of size 8 is split over 4 replicas
      expected_output.extend(
          [[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)])
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(drop_remainder=[True, False])))
  def testFinalPartialBatchAfterRebatch(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(34).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[8] if drop_remainder else [None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
    if not drop_remainder:
      # The last partial batch of size 2 is split over 4 replicas
      expected_output += [[32], [33], [], []]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testMultipleBatches(self):
    dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
    self.assertEqual([[None, None]],
                     [ts.as_list() for ts in _flat_shapes(dataset)])

    # Each element is a list of 8 elements where each element is a list of 4.
    expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                        for j in range(i, i + 32, 4)]  # generates 8 elements
                       for i in range(0, 128, 32)]
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, 4)
    self.assertEqual([[None, None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Each element is a list of 2 elements where each element is a list of 4.
    expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                        for j in range(i, i + 8, 4)]  # generates 2 elements
                       for i in range(0, 128, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testRaggedTensorDataset(self):
    # Set up a dataset that produces ragged tensors with a static batch size.
    row_lengths = np.random.randint(8, size=128)
    values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices(
        ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths))
    dataset = dataset.batch(32, drop_remainder=True)

    # The map changes the internal representation of the ragged tensor.
    # This test will fail if we don't normalize the tensor representation.
    dataset = dataset.map(lambda x: x)

    dataset = distribute._RebatchDataset(dataset, num_replicas=8)
    # After rebatching, batch size is now 4.
    expected_output = []
    value_index = 0
    for batch_row_lengths in row_lengths.reshape((-1, 4)):
      num_values = np.sum(batch_row_lengths)
      expected_output.append(
          ragged_tensor.RaggedTensor.from_row_lengths(
              values[value_index:(value_index + num_values)],
              batch_row_lengths))
      value_index += num_values

    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNoOutputShapes(self):
    # Some datasets, e.g. datasets with None tensors, have components without
    # output shapes. Test that this doesn't break rebatching shape inference
    # logic.
    dataset = dataset_ops.Dataset.range(4)
    dataset = dataset.map(lambda x: (x, None))
    dataset = dataset.batch(4, drop_remainder=True)
    _ = distribute._RebatchDataset(dataset, num_replicas=2)

class AutoShardWithRebatchDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):

  def _setUpFiles(self, num_files, num_records_per_file):
    self._num_files = num_files
    self._num_records = num_records_per_file
    self.test_filenames = self._createFiles()

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithLegacyRebatch(self):
    # Tests that RebatchDatasetV1 is a passthrough op.
    self._setUpFiles(num_files=5, num_records_per_file=10)
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    expected = [[self._record(3, i)] for i in range(10)]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithRebatch(self):
    # Tests that RebatchDatasetV2 is a passthrough op.
    self._setUpFiles(num_files=3, num_records_per_file=5)
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._RebatchDataset(dataset, batch_sizes=[2, 1, 2])
    dataset = distribute._AutoShardDataset(dataset, 3, 1)
    expected = [[self._record(1, 0), self._record(1, 1)], [self._record(1, 2)],
                [self._record(1, 3), self._record(1, 4)]]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(sharding_policy=[
                  distribute_options.AutoShardPolicy.DATA,
                  distribute_options.AutoShardPolicy.AUTO
              ]), combinations.combine(with_prefetch=[True, False]))))
  def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                           with_prefetch):
    # This test simulates a distributed environment with 3 workers, each with
    # 1 replica.
    dataset = dataset_ops.Dataset.range(8)
    dataset = dataset.batch(4)
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
    # RebatchDataset(V1) for correctness reasons. This will modify the output
    # of the dataset.
    worker_a_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[2, 1, 1])
    if with_prefetch:
      worker_a_dataset = worker_a_dataset.prefetch(1)
    worker_a_dataset = distribute._AutoShardDataset(
        worker_a_dataset, 3, 0, num_replicas=3)
    expected = [[0, 1], [4, 5]]
    self.assertDatasetProduces(worker_a_dataset, expected)

    worker_b_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 1, 2])
    if with_prefetch:
      worker_b_dataset = worker_b_dataset.prefetch(1)
    worker_b_dataset = distribute._AutoShardDataset(
        worker_b_dataset, 3, 1, num_replicas=3)
    expected = [[2, 3], [6, 7]]
    self.assertDatasetProduces(worker_b_dataset, expected)

    worker_c_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 2, 1])
    if with_prefetch:
      worker_c_dataset = worker_c_dataset.prefetch(1)
    worker_c_dataset = distribute._AutoShardDataset(
        worker_c_dataset, 3, 2, num_replicas=3)
    expected = [[], []]
    self.assertDatasetProduces(worker_c_dataset, expected)

class ShuffleDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):

  def _build_shuffle_dataset(
      self,
      range_limit=10,
      num_repeats=5,
      buffer_size=5,
      seed=None,
      reshuffle_each_iteration=None,
  ):
    return dataset_ops.Dataset.range(range_limit).shuffle(
        buffer_size,
        seed=seed,
        reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              reshuffle_each_iteration=[True, False],
              buffer_size=[1, 3, 5, 8, 10])))
  def testShuffleCore(self, reshuffle_each_iteration, buffer_size):
    seed = 55
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats
    # pylint: disable=g-long-lambda
    self.run_core_tests(
        lambda: self._build_shuffle_dataset(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration), num_outputs)

  @combinations.generate(
      combinations.combine(
          tf_api_version=1,
          mode=["graph"],
          reshuffle_each_iteration=[True, False],
          buffer_size=[1, 3, 5, 8, 10]))
  def testMultipleIterators(self, reshuffle_each_iteration, buffer_size):
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats

    def ds_fn():
      # pylint: disable=cell-var-from-loop
      return self._build_shuffle_dataset(
          range_limit=range_limit,
          num_repeats=num_repeats,
          buffer_size=buffer_size,
          seed=None,  # Iterator seeds are generated non-deterministically.
          reshuffle_each_iteration=reshuffle_each_iteration)
      # pylint: enable=cell-var-from-loop

    with ops.Graph().as_default() as g:
      ds = ds_fn()
      iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
      get_next_ops = [it.get_next() for it in iterators]
      saveables = [
          contrib_iterator_ops.make_saveable_from_iterator(it)
          for it in iterators
      ]
      for saveable in saveables:
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
      saver = saver_lib.Saver(allow_empty=True)
      with self.session(graph=g) as sess:
        self._save(sess, saver)
        expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
        self._restore(saver, sess)
        actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
        self.match(expected, actual)

class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):

  def setUp(self):
    super(FileCacheTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    if self.tmp_dir:
      shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(FileCacheTest, self).tearDown()

  @combinations.generate(test_base.default_test_combinations())
  def testCacheDatasetPassthrough(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    def dataset_fn(count=5, filename=None):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if filename:
        return repeat_dataset.cache(filename)
      else:
        return repeat_dataset

    self.assertEqual(
        tuple([c.shape[1:] for c in components]),
        dataset_ops.get_legacy_output_shapes(dataset_fn()))

    get_next = self.getNext(dataset_fn())

    # First run without caching to collect the "ground truth".
    elements = []
    for _ in range(20):
      elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the cached dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
    cached_elements = []
    for _ in range(20):
      cached_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(elements, cached_elements)

    # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
    # if we didn't use the cache).
    get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
    replayed_elements = []
    for _ in range(20):
      replayed_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(cached_elements, replayed_elements)

    # Re-initialize with an empty upstream and a missing cache file (should
    # throw errors.OutOfRangeError immediately).
    get_next = self.getNext(
        dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentWriters(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    self.evaluate(get_next1())  # this should succeed
    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(get_next2())
    self.evaluate(get_next1())  # this should continue to succeed

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentReaders(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    elements = []
    for _ in range(4):
      elements.append(self.evaluate(get_next1()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    # Re-initialize
    get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
    get_next2 = self.getNext(cache_dataset2, requires_initialization=True)

    # Reading concurrently should succeed.
    elements_itr1 = []
    elements_itr2 = []
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    # Intentionally reversing the order
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    self.assertAllEqual(elements, elements_itr1)
    self.assertAllEqual(elements, elements_itr2)

  @combinations.generate(test_base.default_test_combinations())
  def testReadingPastEndOfSequence(self):
    dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
    dataset = dataset.map(lambda a: a).batch(4).repeat(2)
    expected_output = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] * 2
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCleaningUpCacheFiles(self):

    def do_test(i):
      dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
      get_next = self.getNext(dataset)
      for _ in range(i):
        try:
          self.evaluate(get_next())
        except errors.OutOfRangeError:
          break

    if not context.executing_eagerly():
      self.skipTest(
          "Test requires eager mode for iterators to be deconstructed")

    for i in [0, 3, 10, 12, 15]:
      do_test(i)

class MultiDeviceIteratorTest(test_base.DatasetTestBase,
                              parameterized.TestCase):

  def setUp(self):
    super(MultiDeviceIteratorTest, self).setUp()
    self._devices = self.configureDevicesForMultiDeviceTest(3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_inits=[0, 1, 42])))
  def testInitOnly(self, num_inits):
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])
    for _ in range(num_inits):
      self.evaluate(multi_device_iterator.initializer)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              max_buffer_size=[0, 1, 10], prefetch_buffer_size=[0, 1, 10])))
  def testBasic(self, prefetch_buffer_size, max_buffer_size):
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]],
        max_buffer_size=max_buffer_size,
        prefetch_buffer_size=prefetch_buffer_size)

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testOneOnSameDevice(self):
    dataset = dataset_ops.Dataset.range(12)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[0], self._devices[1], self._devices[2]])

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 12, 3):
      elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_0))
      self.assertEqual(i + 1, self.evaluate(elem_on_1))
      self.assertEqual(i + 2, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_0)
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatDevices(self):
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[1]])

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elements = multi_device_iterator.get_next()
      elem_on_1, elem_on_2 = elements
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elements = multi_device_iterator.get_next()
      elem_on_1, elem_on_2 = elements
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testNotFullyDivisible(self):
    dataset = dataset_ops.Dataset.range(9)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 8, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    elem_on_1 = multi_device_iterator.get_next(self._devices[1])
    self.assertEqual(8, self.evaluate(elem_on_1))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.default_test_combinations())
  def testGetNextAsOptional(self):
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
      has_elem_1, get_elem_1 = self.evaluate(
          [elem_on_1.has_value(), elem_on_1.get_value()])
      has_elem_2, get_elem_2 = self.evaluate(
          [elem_on_2.has_value(), elem_on_2.get_value()])
      self.assertTrue(has_elem_1)
      self.assertEqual(i, get_elem_1)
      self.assertTrue(has_elem_2)
      self.assertEqual(i + 1, get_elem_2)
    elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
    has_elem_1 = elem_on_1.has_value()
    has_elem_2 = elem_on_2.has_value()
    self.assertFalse(self.evaluate(has_elem_1))
    self.assertFalse(self.evaluate(has_elem_2))
    with self.assertRaises(errors.InvalidArgumentError):
      elem_1 = elem_on_1.get_value()
      self.evaluate(elem_1)
    with self.assertRaises(errors.InvalidArgumentError):
      elem_2 = elem_on_2.get_value()
      self.evaluate(elem_2)

  @combinations.generate(test_base.default_test_combinations())
  def testUneven(self):
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]], max_buffer_size=4)

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1 = multi_device_iterator.get_next(self._devices[1])
      self.assertEqual(i, self.evaluate(elem_on_1))
    for i in range(0, 10, 2):
      elem_on_2 = multi_device_iterator.get_next(self._devices[2])
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

  @combinations.generate(test_base.graph_only_combinations())
  def testMultipleInitializationsGraph(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]], prefetch_buffer_size=4)
    elem_on_1, elem_on_2 = multi_device_iterator.get_next()
    for _ in range(5):
      self.evaluate(multi_device_iterator.initializer)
      self.assertEqual([(0, 0), (1, 1)],
                       self.evaluate([elem_on_1, elem_on_2]))

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleInitializationsEager(self):
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

    for _ in range(5):
      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
          dataset, [self._devices[1], self._devices[2]],
          prefetch_buffer_size=4)
      self.evaluate(multi_device_iterator.initializer)
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual([(0, 0), (1, 1)],
                       self.evaluate([elem_on_1, elem_on_2]))

  @combinations.generate(test_base.default_test_combinations())
  def testOptimization(self):
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
    dataset = dataset.skip(0)  # this should be optimized away
    dataset = dataset.cache()

    options = options_lib.Options()
    options.experimental_optimization.noop_elimination = True
    dataset = dataset.with_options(options)

    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, [self._devices[1], self._devices[2]])

    self.evaluate(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual(i, self.evaluate(elem_on_1))
      self.assertEqual(i + 1, self.evaluate(elem_on_2))
    with self.assertRaises(errors.OutOfRangeError):
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.evaluate(elem_on_1)
      self.evaluate(elem_on_2)

class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testCacheDatasetPassthrough(self):
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      self.evaluate(repeat_count.initializer)
      # Needs to be initializable to capture the variable.
      cached_next = self.getNext(cached_dataset, requires_initialization=True)
      uncached_next = self.getNext(
          uncached_dataset, requires_initialization=True)
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)
          self.assertEqual(self.evaluate(uncached_next()), i)

      self.evaluate(repeat_count.assign(0))

      # The uncached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(uncached_next())

      # The cached iterator replays from cache.
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)

      # The cached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(cached_next())

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyCacheReading(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    repeat_dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
    cache_dataset = repeat_dataset.cache()

    self.assertDatasetProduces(cache_dataset, expected_output=[])

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentReaders(self):
    dataset_fn = lambda: dataset_ops.Dataset.range(5).cache()
    d1 = dataset_fn().map(lambda x: x + 1)
    d2 = dataset_fn().map(lambda x: x + 6)

    get_next1 = self.getNext(d1)
    self.assertEqual(1, self.evaluate(get_next1()))
    self.assertEqual(2, self.evaluate(get_next1()))
    self.assertEqual(3, self.evaluate(get_next1()))

    get_next2 = self.getNext(d2)
    self.assertEqual(6, self.evaluate(get_next2()))
    self.assertEqual(7, self.evaluate(get_next2()))
    self.assertEqual(4, self.evaluate(get_next1()))  # interleave execution
    self.assertEqual([8, 5],
                     [self.evaluate(get_next2()),
                      self.evaluate(get_next1())])
    self.assertEqual(9, self.evaluate(get_next2()))
    self.assertEqual(10, self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

  @combinations.generate(test_base.default_test_combinations())
  def testCacheTakeRepeat(self):
    dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)
    expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCacheRepeatEpochs(self):
    counter = variables.Variable(0)
    self.evaluate(counter.initializer)

    def increment_fn(x):
      counter.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache().repeat(2)
    get_next = self.getNext(dataset, requires_initialization=True)

    # first epoch
    for i in range(10):
      self.assertEqual(i, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(get_next()))
    # second epoch
    for i in range(10):
      self.assertEqual(10, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheIterationEpochs(self):
    counter = variables.Variable(0)
    self.evaluate(counter.initializer)

    def increment_fn(x):
      counter.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache()

    # first epoch
    i = 0
    for elem in dataset:
      self.assertEqual(i, self.evaluate(elem))
      i += 1
      self.assertEqual(i, self.evaluate(counter))

    # second epoch
    i = 0
    for elem in dataset:
      self.assertEqual(10, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(elem))
      i += 1

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheV2ResourceCapture(self):

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      ids = ids.cache()

      def interleave_fn(dataset, _):
        return dataset

      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = []
    for elem in make_dataset():
      results.append(elem.numpy())

    self.assertAllEqual(results, range(10))

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheV2ConcurrentIterators(self):
    dataset = dataset_ops.Dataset.range(10).cache()

    it1 = iter(dataset)
    it2 = iter(dataset)

    for i in range(10):
      self.assertEqual(next(it1), i)
      self.assertEqual(next(it2), i)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheKnownCardinality(self):
    # Check that a dataset which produces a random permutation of range(10)
    # ends up being cached when we read all of its elements but do not reach
    # EOF.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.shuffle(10, reshuffle_each_iteration=True).cache()

    it = iter(dataset)

    results = []
    for _ in range(10):
      results.append(next(it))

    it = iter(dataset)
    for i in range(10):
      self.assertEqual(next(it), results[i])

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointFinishedCache(self):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.cache()
    iterator = iter(ds)
    for i in range(num_elements):
      self.assertEqual(next(iterator).numpy(), i)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
    manager.restore_or_initialize()
    with self.assertRaises(StopIteration):
      next(iterator)

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointLargeCache(self):
    # Tensor of size 100M
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
    # Repeat 25 times to exceed the 2G proto limit
    dataset = dataset.repeat(25)
    dataset = dataset.cache()

    # Iterate to fill the cache.
    iterator = iter(dataset)
    for _ in range(23):
      next(iterator)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()

class CacheCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):

  def setUp(self):
    self.range_size = 10
    self.num_repeats = 3
    self.num_outputs = self.range_size * self.num_repeats
    self.cache_file_prefix = "test"

  def make_dataset_fn(self, is_memory):
    if is_memory:
      filename = ""
    else:
      filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)

    def ds_fn():
      return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
          self.num_repeats)

    return ds_fn

  def expected_outputs(self):
    return list(range(self.range_size)) * self.num_repeats

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 5 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 8 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, range(8))

    outputs = outputs[:5]
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointAfterOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 15 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 15,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 18 entries from iterator but save checkpoint after producing 15.
    outputs = self.gen_outputs(
        ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))

    outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 15,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 13 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))

    # Since we ran for more than one epoch, the cache was completely written.
    # The ckpt was saved when the iterator was in cache-write mode. Test that
    # the iterator falls back to read mode after restoring if the cache has
    # been completely written.
    outputs = list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 5,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointUnusedWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Checkpoint before get_next is called even once.
    outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
    self.assertSequenceEqual(outputs, [])

    outputs = self.gen_outputs(
        ds_fn, [],
        self.num_outputs,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint, then produce no elements and checkpoint.
    outputs.extend(
        self.gen_outputs(ds_fn, [], 0, ckpt_saved=True,
                         verify_exhausted=False))
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce rest of the elements.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testUnusedCheckpointError(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and save ckpt.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    if is_memory:
      outputs = self.gen_outputs(
          ds_fn, [], self.num_outputs, verify_exhausted=False)
      self.assertSequenceEqual(outputs, self.expected_outputs())
    else:
      # Since the complete cache has not been written, a new iterator which
      # does not restore the checkpoint will throw an error since there is a
      # partial cache shard.
      with self.assertRaises(errors.AlreadyExistsError):
        outputs = self.gen_outputs(
            ds_fn, [], self.num_outputs, verify_exhausted=False)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testIgnoreCheckpointIfCacheWritten(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 15 elements and save ckpt. This will write the complete cache.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Build the iterator again but do not restore from ckpt. Since the cache
    # has already been written we should be able to use it.
    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testOptimizationStatefulFunction(self): dataset = dataset_ops.Dataset.range(10).map( lambda _: random_ops.random_uniform([])).batch(10) options = options_lib.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) get_next = self.getNext(dataset) self.evaluate(get_next()) # TODO(b/123354468) @combinations.generate(test_base.graph_only_combinations()) def testOptimizationLargeInputFromTensor(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) dataset = dataset_ops.Dataset.from_tensors(input_t) options = options_lib.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) self.evaluate(get_next) # TODO(b/123354468) @combinations.generate(test_base.graph_only_combinations()) def testOptimizationLargeInputFromTensorSlices(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) dataset = dataset_ops.Dataset.from_tensor_slices(input_t) options = options_lib.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) self.evaluate(get_next) @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDataset(self): def flat_map_fn(_): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"])) dataset = dataset.skip(0) # Should be removed by noop elimination dataset = dataset.cache() return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.flat_map(flat_map_fn) options = options_lib.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.noop_elimination = True dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[0]) @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDatasetWithModifiedRetval(self): def flat_map_fn(_): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["MapAndBatch"])) # Should be fused by map and batch fusion dataset = dataset.map(lambda x: x) dataset = dataset.batch(1) return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.flat_map(flat_map_fn) options = options_lib.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_batch_fusion = True dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[[0]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(autotune=[True, False, None]), combinations.combine(map_parallelization=[True, False, None]))) def testOptimizationMapParallelization(self, autotune, map_parallelization): dataset = dataset_ops.Dataset.range(5) if autotune is not False and map_parallelization is not False: # pylint: 
disable=g-bool-id-comparison dataset = dataset.apply(testing.assert_next(["ParallelMap"])) else: dataset = dataset.apply(testing.assert_next(["Map"])) dataset = dataset.map(lambda x: x + 1) options = options_lib.Options() if autotune is not None: options.autotune.enabled = autotune if map_parallelization is not None: options.experimental_optimization.map_parallelization = ( map_parallelization) dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=list(range(1, 6))) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(existing_prefetch=[True, False]), combinations.combine(autotune=[True, False]), combinations.combine(set_env=[True, False]))) def testOptimizationInjectPrefetch(self, existing_prefetch, autotune, set_env): if set_env: os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "inject_prefetch" os.environ["TF_JOB_NAME"] = "test_job" dataset = dataset_ops.Dataset.range(5) dataset = dataset.map(lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE) if existing_prefetch: dataset = dataset.prefetch(1) if autotune and set_env and not existing_prefetch: dataset = dataset.apply(testing.assert_next(["Prefetch", "Root"])) else: dataset = dataset.apply(testing.assert_next(["Root"])) options = options_lib.Options() options.autotune.enabled = autotune dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=list(range(1, 6))) if set_env: del os.environ["TF_DATA_EXPERIMENT_OPT_IN"] del os.environ["TF_JOB_NAME"] # Reference variables are not supported in eager mode. @combinations.generate( combinations.times(test_base.graph_only_combinations(), _captured_refvar_test_combinations())) def testOptimizationWithCapturedRefVar(self, dataset_fn): """Tests that default optimizations are disabled with ref variables.""" variable = variable_scope.get_variable("v", initializer=0, use_resource=False) assign_op = variable.assign_add(1) unoptimized_dataset = dataset_fn(variable) options = options_lib.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.noop_elimination = True options.experimental_optimization.map_and_batch_fusion = True optimized_dataset = unoptimized_dataset.with_options(options) optimized_it = dataset_ops.make_initializable_iterator( optimized_dataset) # Check that outputs are the same in the optimized and unoptimized cases, # when the variable value is changing. unoptimized_it = dataset_ops.make_initializable_iterator( unoptimized_dataset) with ops.control_dependencies([assign_op]): unoptimized_output = unoptimized_it.get_next() optimized_output = optimized_it.get_next() self.evaluate(variable.initializer) self.evaluate((unoptimized_it.initializer, optimized_it.initializer)) while True: try: unoptimized, optimized = self.evaluate( (unoptimized_output, optimized_output)) self.assertEqual(unoptimized, optimized) except errors.OutOfRangeError: break
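# --- Illustrative sketch, not part of the original test file ---
# The OptimizationTest cases above flip individual tf.data graph rewrites
# through the internal options_lib module. The public equivalent is
# tf.data.Options; the sketch assumes a TF 2.x release in which
# `options.autotune.enabled` and
# `options.experimental_optimization.map_parallelization` are available.
def _example_optimization_options():
  import tensorflow as tf

  options = tf.data.Options()
  options.experimental_optimization.apply_default_optimizations = False
  options.experimental_optimization.map_parallelization = True  # Map -> ParallelMap
  options.autotune.enabled = True

  dataset = tf.data.Dataset.range(5).map(lambda x: x + 1)
  dataset = dataset.with_options(options)
  return list(dataset.as_numpy_iterator())  # [1, 2, 3, 4, 5]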
class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): def _testDataset(self, dataset, function, predicate): expected_output = [] for x in range(10): r = function(x) if isinstance(r, tuple): b = predicate(*r) # Pass tuple as multiple arguments. else: b = predicate(r) if self.evaluate(b): expected_output.append(r) self.assertDatasetProduces(dataset, expected_output=expected_output) def _testMapAndFilterFusion(self, function, predicate): dataset = dataset_ops.Dataset.range(10).apply( testing.assert_next(["Map", "Filter", "Map"])).map(function).filter(predicate) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_filter_fusion = True dataset = dataset.with_options(options) self._testDataset(dataset, function, predicate) @combinations.generate(test_base.default_test_combinations()) def testMapAndFilterFusionScalar(self): identity = lambda x: x increment = lambda x: x + 1 minus_five = lambda x: x - 5 def increment_and_square(x): y = x + 1 return y * y functions = [identity, increment, minus_five, increment_and_square] take_all = lambda x: constant_op.constant(True) is_zero = lambda x: math_ops.equal(x, 0) is_odd = lambda x: math_ops.equal(x % 2, 0) greater = lambda x: math_ops.greater(x + 5, 0) predicates = [take_all, is_zero, is_odd, greater] for function in functions: for predicate in predicates: self._testMapAndFilterFusion(function, predicate) @combinations.generate(test_base.default_test_combinations()) def testMapAndFilterFusionTuple(self): replicate = lambda x: (x, x) with_two = lambda x: (x, 2) functions = [replicate, with_two] take_all = lambda x, y: constant_op.constant(True) is_zero = lambda x, y: math_ops.equal( x * math_ops.cast(y, dtypes.int64), 0) predicates = [take_all, is_zero] for function in functions: for predicate in predicates: self._testMapAndFilterFusion(function, predicate) @combinations.generate(test_base.default_test_combinations()) def testCapturedInputs(self): a = constant_op.constant(3, dtype=dtypes.int64) b = constant_op.constant(4, dtype=dtypes.int64) some_tensor = math_ops.mul(a, b) function = lambda x: x * x def predicate(y): return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) # We are currently not supporting functions with captured inputs. dataset = dataset_ops.Dataset.range(10).apply( testing.assert_next(["Map", "Filter"])).map(function).filter(predicate) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_filter_fusion = True dataset = dataset.with_options(options) self._testDataset(dataset, function, predicate)
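# --- Illustrative sketch, not part of the original test file ---
# Map-and-filter fusion rewrites a `map(f).filter(p)` pair into a single
# fused op; the tests above assert the rewrite via testing.assert_next. With
# the public API the rewrite itself is invisible, but it can be requested
# explicitly (assumes TF 2.x):
def _example_map_and_filter_fusion():
  import tensorflow as tf

  options = tf.data.Options()
  options.experimental_optimization.map_and_filter_fusion = True

  dataset = tf.data.Dataset.range(10)
  dataset = dataset.map(lambda x: x + 1).filter(lambda y: y % 2 == 0)
  dataset = dataset.with_options(options)
  return list(dataset.as_numpy_iterator())  # [2, 4, 6, 8, 10]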
class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): super(TFRecordWriterTest, self).setUp() self._num_records = 8 def writer_fn(self, filename, compression_type=""): input_dataset = readers.TFRecordDataset([filename], compression_type) return writers.TFRecordWriter(self._outputFilename(), compression_type).write(input_dataset) def _record(self, i): return compat.as_bytes("Record %d" % (i)) def _createFile(self, options=None): filename = self._inputFilename() writer = python_io.TFRecordWriter(filename, options) for i in range(self._num_records): writer.write(self._record(i)) writer.close() return filename def _inputFilename(self): return os.path.join(self.get_temp_dir(), "tf_record.in.txt") def _outputFilename(self): return os.path.join(self.get_temp_dir(), "tf_record.out.txt") @combinations.generate(test_base.default_test_combinations()) def testWrite(self): self.evaluate(self.writer_fn(self._createFile())) for i, r in enumerate( tf_record.tf_record_iterator(self._outputFilename())): self.assertAllEqual(self._record(i), r) @combinations.generate(test_base.default_test_combinations()) def testWriteZLIB(self): options = tf_record.TFRecordOptions( tf_record.TFRecordCompressionType.ZLIB) self.evaluate( self.writer_fn(self._createFile(options), compression_type="ZLIB")) for i, r in enumerate( tf_record.tf_record_iterator(self._outputFilename(), options=options)): self.assertAllEqual(self._record(i), r) @combinations.generate(test_base.default_test_combinations()) def testWriteGZIP(self): options = tf_record.TFRecordOptions( tf_record.TFRecordCompressionType.GZIP) self.evaluate( self.writer_fn(self._createFile(options), compression_type="GZIP")) for i, r in enumerate( tf_record.tf_record_iterator(self._outputFilename(), options=options)): self.assertAllEqual(self._record(i), r) @combinations.generate(test_base.default_test_combinations()) def testFailDataset(self): with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write("whoops") @combinations.generate(test_base.default_test_combinations()) def testFailDType(self): input_dataset = dataset_ops.Dataset.from_tensors(10) with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) @combinations.generate(test_base.default_test_combinations()) def testFailShape(self): input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]]) with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) @combinations.generate(test_base.default_test_combinations()) def testSideEffect(self): def writer_fn(): input_dataset = readers.TFRecordDataset(self._createFile()) return writers.TFRecordWriter( self._outputFilename()).write(input_dataset) @function.defun def fn(): _ = writer_fn() return "hello" self.assertEqual(self.evaluate(fn()), b"hello") for i, r in enumerate( tf_record.tf_record_iterator(self._outputFilename())): self.assertAllEqual(self._record(i), r) @combinations.generate(test_base.default_test_combinations()) def testShard(self): filename = self._createFile() dataset = readers.TFRecordDataset([filename]) def reduce_func(key, dataset): shard_filename = string_ops.string_join( [filename, string_ops.as_string(key)]) writer = writers.TFRecordWriter(shard_filename) writer.write(dataset.map(lambda _, x: x)) return dataset_ops.Dataset.from_tensors(shard_filename) dataset = dataset.enumerate() dataset = dataset.apply( grouping.group_by_window(lambda i, _: i % 2, reduce_func, 
                                 dtypes.int64.max))
    get_next = self.getNext(dataset)

    for i in range(2):
      shard_filename = (filename + str(i)).encode()
      self.assertEqual(self.evaluate(get_next()), shard_filename)
      for j, r in enumerate(
          tf_record.tf_record_iterator(shard_filename)):
        self.assertAllEqual(self._record(i + 2 * j), r)
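# --- Illustrative sketch, not part of the original test file ---
# The writers.TFRecordWriter used above is exposed publicly as
# tf.data.experimental.TFRecordWriter: it consumes a dataset of scalar
# strings and writes one TFRecord per element, which TFRecordDataset can
# read back (assumes TF 2.x eager mode; the output path is hypothetical).
def _example_tfrecord_writer_roundtrip(path="/tmp/example.tfrecord"):
  import tensorflow as tf

  records = tf.data.Dataset.from_tensor_slices(
      [b"Record %d" % i for i in range(8)])
  tf.data.experimental.TFRecordWriter(path).write(records)

  read_back = tf.data.TFRecordDataset([path])
  return list(read_back.as_numpy_iterator())  # the original 8 byte strings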
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testBasic(self): components = ( np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), np.array([9.0, 10.0, 11.0, 12.0]) ) def dataset_fn(count=5, buffer_size=None, seed=0): repeat_dataset = ( dataset_ops.Dataset.from_tensor_slices(components).repeat(count)) if buffer_size: shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed) self.assertEqual( tuple([c.shape[1:] for c in components]), dataset_ops.get_legacy_output_shapes(shuffle_dataset)) return shuffle_dataset else: return repeat_dataset # First run without shuffling to collect the "ground truth". get_next = self.getNext(dataset_fn()) unshuffled_elements = [] for _ in range(20): unshuffled_elements.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) # Assert that the shuffled dataset has the same elements as the # "ground truth". get_next = self.getNext(dataset_fn(buffer_size=100, seed=37)) shuffled_elements = [] for _ in range(20): shuffled_elements.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements)) # Assert that shuffling twice with the same seeds gives the same sequence. get_next = self.getNext(dataset_fn(buffer_size=100, seed=37)) reshuffled_elements_same_seed = [] for _ in range(20): reshuffled_elements_same_seed.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertEqual(shuffled_elements, reshuffled_elements_same_seed) # Assert that shuffling twice with a different seed gives a different # permutation of the same elements. get_next = self.getNext(dataset_fn(buffer_size=100, seed=137)) reshuffled_elements_different_seed = [] for _ in range(20): reshuffled_elements_different_seed.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed) self.assertAllEqual( sorted(shuffled_elements), sorted(reshuffled_elements_different_seed)) # Assert that the shuffled dataset has the same elements as the # "ground truth" when the buffer size is smaller than the input # dataset. get_next = self.getNext(dataset_fn(buffer_size=2, seed=37)) reshuffled_elements_small_buffer = [] for _ in range(20): reshuffled_elements_small_buffer.append(self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertAllEqual( sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer)) # Test the case of shuffling an empty dataset. 
get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37)) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) def testSeedZero(self): """Test for same behavior when the seed is a Python or Tensor zero.""" iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.range(10).shuffle(10, seed=0)) get_next = iterator.get_next() elems = [] with self.cached_session() as sess: for _ in range(10): elems.append(sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder)) get_next = iterator.get_next() with self.cached_session() as sess: sess.run(iterator.initializer, feed_dict={seed_placeholder: 0}) for elem in elems: self.assertEqual(elem, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.default_test_combinations()) def testDefaultArguments(self): components = [0, 1, 2, 3, 4] dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle( 5).repeat() get_next = self.getNext(dataset) counts = collections.defaultdict(lambda: 0) for _ in range(10): for _ in range(5): counts[self.evaluate(get_next())] += 1 for i in range(5): self.assertEqual(10, counts[i]) @combinations.generate( combinations.times( test_base.graph_only_combinations(), combinations.combine(reshuffle=[True, False]), combinations.combine(graph_seed=38, op_seed=None) + combinations.combine(graph_seed=None, op_seed=42) + combinations.combine(graph_seed=38, op_seed=42))) def testShuffleSeed(self, reshuffle, graph_seed, op_seed): results = [] for _ in range(2): with ops.Graph().as_default() as g: random_seed.set_random_seed(graph_seed) dataset = dataset_ops.Dataset.range(10).shuffle( 10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() run_results = [] with self.session(graph=g) as sess: for _ in range(30): run_results.append(sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) results.append(run_results) self.assertAllEqual(results[0], results[1]) # TODO(b/117581999): enable this test for eager-mode. 
@combinations.generate( combinations.times( test_base.graph_only_combinations(), combinations.combine( reshuffle=[True, False], initializable=[True, False]))) def testMultipleIterators(self, reshuffle, initializable): with ops.Graph().as_default() as g: dataset = dataset_ops.Dataset.range(100).shuffle( 10, reshuffle_each_iteration=reshuffle).repeat(3) if initializable: iterators = [dataset_ops.make_initializable_iterator(dataset) for _ in range(2)] else: iterators = [dataset_ops.make_one_shot_iterator(dataset) for _ in range(2)] results = [] with self.session(graph=g) as sess: for iterator in iterators: if initializable: sess.run(iterator.initializer) next_element = iterator.get_next() run_results = [] for _ in range(300): run_results.append(sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) results.append(run_results) self.assertNotEqual(results[0], results[1]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleRepeatEpochs(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10).shuffle( 10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2) next_element = self.getNext(dataset) first_epoch = [] for _ in range(10): first_epoch.append(self.evaluate(next_element())) second_epoch = [] for _ in range(10): second_epoch.append(self.evaluate(next_element())) self.assertEqual(first_epoch == second_epoch, not reshuffle) @combinations.generate( combinations.times( combinations.combine(tf_api_version=2, mode="eager"), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleIterationEpochs(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10).shuffle( 10, seed=seed, reshuffle_each_iteration=reshuffle) first_epoch = [] for elem in dataset: first_epoch.append(elem.numpy()) second_epoch = [] for elem in dataset: second_epoch.append(elem.numpy()) self.assertEqual(first_epoch == second_epoch, not reshuffle) @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) def testShuffleV2ResourceCapture(self): def make_dataset(): ids = dataset_ops.Dataset.range(10) ids = ids.shuffle(1) def interleave_fn(dataset, _): return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.interleave(functools.partial(interleave_fn, ids)) return dataset results = [] for elem in make_dataset(): results.append(elem.numpy()) self.assertAllEqual(results, range(10)) @combinations.generate( combinations.times( test_base.eager_only_combinations(), combinations.combine(reshuffle=[True, False], seed=[None, 42]))) def testReshuffleSeparateTransformations(self, reshuffle, seed): dataset = dataset_ops.Dataset.range(10) first_epoch = [] for elem in dataset.shuffle( 10, seed=seed, reshuffle_each_iteration=reshuffle): first_epoch.append(elem.numpy()) second_epoch = [] for elem in dataset.shuffle( 10, seed=seed, reshuffle_each_iteration=reshuffle): second_epoch.append(elem.numpy()) self.assertEqual(first_epoch != second_epoch, seed is None) @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) def testShuffleV2InFunction(self): counter_var = variables.Variable(0) @function.defun def consume(): ds = dataset_ops.Dataset.range(10) ds = ds.shuffle(1) for _ in ds: counter_var.assign(counter_var + 1) consume() self.assertAllEqual(self.evaluate(counter_var), 10) @combinations.generate(test_base.default_test_combinations()) def testEmptyDataset(self): dataset = 
dataset_ops.Dataset.from_tensors(1)

    def map_fn(x):
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return x

    dataset = dataset.map(map_fn)
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=10).repeat()
    get_next = self.getNext(dataset)

    # First time around, we get an error for the failed assertion.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())

    # Second time around, we get an EOF because the cached dataset is empty.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
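# --- Illustrative sketch, not part of the original test file ---
# The ShuffleTest cases above revolve around two knobs: `seed` (the same seed
# gives the same permutation for independently built datasets) and
# `reshuffle_each_iteration` (whether a second pass over the same dataset
# sees a new permutation). A compact public-API check (assumes TF 2.x eager
# mode):
def _example_shuffle_determinism():
  import tensorflow as tf

  ds = tf.data.Dataset.range(10).shuffle(
      10, seed=42, reshuffle_each_iteration=False)
  first_epoch = list(ds.as_numpy_iterator())
  second_epoch = list(ds.as_numpy_iterator())
  assert first_epoch == second_epoch  # reshuffle disabled: identical order

  ds = tf.data.Dataset.range(10).shuffle(
      10, seed=42, reshuffle_each_iteration=True)
  epochs = [list(ds.as_numpy_iterator()) for _ in range(2)]
  assert sorted(epochs[0]) == sorted(epochs[1])  # same elements, order may differ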
class BucketBySequenceLengthTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_no_padding=[True, False]))) def testBucketDropReminder(self, param_no_padding): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] n_bucket_elements = [28, 7, 6, 5] n_expected_batches = 5 # Expected sequence lengths of the individual batches. expected_lengths = [] # Expected sum of all batches with an equal sequence length. # <seq-length>: <expected-total-sum> expected_sums = {} # Expected batch sizes of batches depending on the sequence length. # <seq-length>: [batch1_size, ..., batchN_size] expected_batch_sizes = {} for length, batch_size, bucket_elements in zip(lengths, batch_sizes, n_bucket_elements): # Calculate the expected sum across all batches of a specific sequence length. expected_sums[length] = \ (bucket_elements - bucket_elements % batch_size) * length # Calculate the expected occurrence of individual batch sizes. expected_batch_sizes[length] = \ [batch_size] * (bucket_elements // batch_size) # Calculate the expected occurrence of individual sequence lengths. expected_lengths.extend([length] * (bucket_elements // batch_size)) def build_dataset(sparse): def _generator(): # Produce 1 batch for each bucket elements = [] for bucket_elements, length in zip(n_bucket_elements, lengths): # Using only full sequences (opposed to the strategy employed in `testBucket`) makes # checking the sum a lot easier. record_len = length for _ in range(bucket_elements): elements.append([1] * record_len) random.shuffle(elements) for el in elements: yield (_format_record(el, sparse), ) dataset = dataset_ops.Dataset.from_generator( _generator, (_get_record_type(sparse), ), (_get_record_shape(sparse), )) if sparse: dataset = dataset.map(lambda x: (_to_sparse_tensor(x), )) return dataset def _test_bucket_by_padding(no_padding): dataset = build_dataset(sparse=no_padding) dataset = dataset.apply( grouping.bucket_by_sequence_length(_element_length_fn, boundaries, batch_sizes, no_padding=no_padding, drop_remainder=True)) get_next = self.getNext(dataset) batches = [] for _ in range(n_expected_batches): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) generated_lengths = [] # <seq-length>: <total-sum> generated_sums = {} # <seq-length>: [<batch_size>, ...] generated_batch_sizes = {} for length, batch_size, bucket_elements in zip( lengths, batch_sizes, n_bucket_elements): # Initialize the sum across all batches. generated_sums[length] = 0 # Initialize the individual batch sizes. generated_batch_sizes[length] = [] for batch in batches: shape = batch.dense_shape if no_padding else batch.shape length = shape[1] generated_lengths.append(length) batch_size = shape[0] generated_batch_sizes[length].append(batch_size) batch_sum = batch.values.sum() if no_padding else batch.sum() generated_sums[length] += batch_sum for l in lengths: # Make sure the sum of the batch contents is correct for the individual sequence lengths. self.assertEqual( generated_sums[l], expected_sums[l], "Tensor sums did not match! " "expected: {}, generated: {}".format( expected_sums, generated_sums)) # Make sure the individual batch sizes are generated as expected. self.assertEqual( sorted(generated_batch_sizes[l]), sorted(expected_batch_sizes[l]), "Batch-sizes did not match! 
" "expected: {}, generated: {}".format( sorted(expected_batch_sizes[l]), sorted(generated_batch_sizes[l]))) # Make sure the generated sequence lengths appear as often as expected. self.assertEqual( sorted(generated_lengths), sorted(expected_lengths), "The generated sequence lengths did not match! " "expected: {}, generated: {}".format( sorted(expected_lengths), sorted(generated_lengths))) _test_bucket_by_padding(param_no_padding) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_no_padding=[True, False]))) def testBucket(self, param_no_padding): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] def build_dataset(sparse): def _generator(): # Produce 1 batch for each bucket elements = [] for batch_size, length in zip(batch_sizes, lengths): record_len = length - 1 for _ in range(batch_size): elements.append([1] * record_len) record_len = length random.shuffle(elements) for el in elements: yield (_format_record(el, sparse), ) dataset = dataset_ops.Dataset.from_generator( _generator, (_get_record_type(sparse), ), (_get_record_shape(sparse), )) if sparse: dataset = dataset.map(lambda x: (_to_sparse_tensor(x), )) return dataset def _test_bucket_by_padding(no_padding): dataset = build_dataset(sparse=no_padding) dataset = dataset.apply( grouping.bucket_by_sequence_length(_element_length_fn, boundaries, batch_sizes, no_padding=no_padding)) get_next = self.getNext(dataset) batches = [] for _ in range(4): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) batch_sizes_val = [] lengths_val = [] for batch in batches: shape = batch.dense_shape if no_padding else batch.shape batch_size = shape[0] length = shape[1] batch_sizes_val.append(batch_size) lengths_val.append(length) if not context.executing_eagerly(): sum_check = batch.values.sum( ) if no_padding else batch.sum() self.assertEqual(sum_check, batch_size * length - 1) self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) self.assertEqual(sorted(lengths), sorted(lengths_val)) _test_bucket_by_padding(param_no_padding) def testPadToBoundary(self): boundaries = [10, 20, 30] batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25] def element_gen(): # Produce 1 batch for each bucket elements = [] for batch_size, length in zip(batch_sizes[:-1], lengths): for _ in range(batch_size): elements.append([1] * length) random.shuffle(elements) for el in elements: yield (el, ) for _ in range(batch_sizes[-1]): el = [1] * (boundaries[-1] + 5) yield (el, ) element_len = lambda el: array_ops.shape(el)[0] dataset = dataset_ops.Dataset.from_generator( element_gen, (dtypes.int64, ), ([None], )).apply( grouping.bucket_by_sequence_length( element_len, boundaries, batch_sizes, pad_to_bucket_boundary=True)) get_next = self.getNext(dataset) batches = [] for _ in range(3): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaisesOpError("bucket_boundaries"): self.evaluate(get_next()) batch_sizes_val = [] lengths_val = [] for batch in batches: batch_size = batch.shape[0] length = batch.shape[1] batch_sizes_val.append(batch_size) lengths_val.append(length) batch_sizes = batch_sizes[:-1] self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], sorted(lengths_val)) def 
testPadToBoundaryNoExtraneousPadding(self): boundaries = [3, 7, 11] batch_sizes = [2, 2, 2, 2] lengths = range(1, 11) def element_gen(): for length in lengths: yield ([1] * length, ) element_len = lambda element: array_ops.shape(element)[0] dataset = dataset_ops.Dataset.from_generator( element_gen, (dtypes.int64, ), ([None], )).apply( grouping.bucket_by_sequence_length( element_len, boundaries, batch_sizes, pad_to_bucket_boundary=True)) get_next = self.getNext(dataset) batches = [] for _ in range(5): batch, = self.evaluate(get_next()) batches.append(batch) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) self.assertAllEqual(batches[0], [[1, 0], [1, 1]]) self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]]) self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]) self.assertAllEqual( batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) self.assertAllEqual( batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_no_padding=[True, False]))) def testTupleElements(self, param_no_padding): def build_dataset(sparse): def _generator(): text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] label = [1, 2, 1, 2] for x, y in zip(text, label): yield (_format_record(x, sparse), y) dataset = dataset_ops.Dataset.from_generator( generator=_generator, output_types=(_get_record_type(sparse), dtypes.int32), output_shapes=(_get_record_shape(sparse), tensor_shape.TensorShape([]))) if sparse: dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y)) return dataset def _test_tuple_elements_by_padding(no_padding): dataset = build_dataset(sparse=no_padding) dataset = dataset.apply( grouping.bucket_by_sequence_length( element_length_func=_element_length_fn, bucket_batch_sizes=[2, 2, 2], bucket_boundaries=[0, 8], no_padding=no_padding)) shapes = dataset_ops.get_legacy_output_shapes(dataset) self.assertEqual([None, None], shapes[0].as_list()) self.assertEqual([None], shapes[1].as_list()) _test_tuple_elements_by_padding(param_no_padding) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(param_drop_remainder=[True, False]))) def testBucketSparse(self, param_drop_remainder): # pylint: disable=g-doc-args """Tests bucketing of sparse tensors (case where `no_padding` == True). Test runs on following dataset: [ [0], [0, 1], [0, 1, 2] ... [0, ..., max_len - 1] ] Sequences are bucketed by length and batched with `batch_size` < `bucket_size`. 
""" min_len = 0 max_len = 100 batch_size = 7 bucket_size = 10 def _build_dataset(): input_data = [range(i + 1) for i in range(min_len, max_len)] def generator_fn(): for record in input_data: yield _format_record(record, sparse=True) dataset = dataset_ops.Dataset.from_generator( generator=generator_fn, output_types=_get_record_type(sparse=True)) dataset = dataset.map(_to_sparse_tensor) return dataset def _compute_expected_batches(drop_remainder): """Computes expected batch outputs and stores in a set.""" all_expected_sparse_tensors = set() for bucket_start_len in range(min_len, max_len, bucket_size): if drop_remainder: batch_offsets = [0] else: batch_offsets = range(0, bucket_size, batch_size) for batch_offset in batch_offsets: batch_start_len = bucket_start_len + batch_offset batch_end_len = min(batch_start_len + batch_size, bucket_start_len + bucket_size) expected_indices = [] expected_values = [] for length in range(batch_start_len, batch_end_len): for val in range(length + 1): expected_indices.append( (length - batch_start_len, val)) expected_values.append(val) expected_sprs_tensor = (tuple(expected_indices), tuple(expected_values)) all_expected_sparse_tensors.add(expected_sprs_tensor) return all_expected_sparse_tensors def _compute_batches(dataset): """Computes actual batch outputs of dataset and stores in a set.""" batch = self.getNext(dataset) all_sparse_tensors = set() with self.assertRaises(errors.OutOfRangeError): while True: output = self.evaluate(batch()) sprs_tensor = (tuple([ tuple(idx) for idx in output.indices ]), tuple(output.values)) all_sparse_tensors.add(sprs_tensor) return all_sparse_tensors dataset = _build_dataset() boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) dataset = dataset.apply( grouping.bucket_by_sequence_length( _element_length_fn, boundaries, [batch_size] * (len(boundaries) + 1), no_padding=True, drop_remainder=param_drop_remainder)) batches = _compute_batches(dataset) expected_batches = _compute_expected_batches(param_drop_remainder) self.assertEqual(batches, expected_batches)
class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, parameterized.TestCase): def setUp(self): super(AutoShardDatasetTest, self).setUp() self._num_files = 10 self._num_records = 10 self.test_filenames = self._createFiles() def getAllDatasetElements(self, dataset): actual = [] next_fn = self.getNext(dataset) while True: try: actual.append(self.evaluate(next_fn())) except errors.OutOfRangeError: break return actual def assertDatasetProducesWithShuffle(self, dataset, expected, batch, num_examples, shuffle): if shuffle: actual = [] next_fn = self.getNext(dataset) for _ in range(num_examples): elem = self.evaluate(next_fn()) if isinstance(elem, tuple): actual.extend(elem) else: actual.extend(elem.tolist()) self.assertCountEqual(actual, expected) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_fn()) else: self.assertDatasetProduces(dataset, list(chunk(expected, batch))) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(shuffle=[True, False]))) def testFlatMapReaderPipeline(self, shuffle): dataset = dataset_ops.Dataset.list_files( self.test_filenames, shuffle=shuffle) dataset = dataset.flat_map(core_readers.TFRecordDataset) dataset = dataset.batch(5) dataset = distribute._AutoShardDataset(dataset, 5, 3) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in (3, 8) for r in range(0, 10) ] self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(batch_size=[1, 3, 10]))) def testDatasetOfReaderDatasetsPipeline(self, batch_size): # This tests a scenario where a list_files main return multiple files # due to the glob containing wildcards. def batch(iterator, n): l = len(iterator) for i in range(0, l, n): yield iterator[i:min(i + n, l)] datasets = [] for files in batch(self.test_filenames, batch_size): datasets.append( dataset_ops.Dataset.list_files(files, shuffle=False).map( core_readers.TFRecordDataset)) dataset = dataset_ops.Dataset.from_tensor_slices(datasets) dataset = dataset.flat_map(lambda x: x) # Simulate additional ops in between flat_map and interleave. This should be # a no-op since if ShardDataset is placed right after flat_map, we will only # have two datasets left at this point. 
dataset = dataset.prefetch(1) dataset = dataset.prefetch(1) dataset = dataset.interleave( lambda x: x, cycle_length=1, num_parallel_calls=1) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in (0, 5) for r in range(0, 10) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testZipReaderPipeline(self): dataset1 = dataset_ops.Dataset.list_files( self.test_filenames, shuffle=False) dataset1 = dataset1.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset2 = dataset_ops.Dataset.list_files( self.test_filenames, shuffle=False) dataset2 = dataset2.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) dataset = distribute._AutoShardDataset(dataset, 5, 3) expected = [ (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f)) # pylint:disable=g-complex-comprehension for r in range(0, 10) for f in (3, 8) ] self.assertDatasetProduces(dataset, expected) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(shuffle=[True, False]))) def testConcatenateReaderPipeline(self, shuffle): dataset1 = dataset_ops.Dataset.list_files( self.test_filenames, shuffle=shuffle) dataset1 = dataset1.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset1 = dataset1.batch(5) dataset2 = dataset_ops.Dataset.list_files( self.test_filenames, shuffle=shuffle) dataset2 = dataset2.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset2 = dataset2.batch(5) dataset = dataset1.concatenate(dataset2) dataset = distribute._AutoShardDataset(dataset, 5, 3) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for r in range(0, 10) for f in (3, 8) ] expected += expected self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(shuffle=[True, False]))) def testPipelineWithMap(self, shuffle): dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) dataset = dataset.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000)) dataset = dataset.batch(5) dataset = distribute._AutoShardDataset(dataset, 5, 3) expected = [ b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for r in range(0, 10) for f in (3, 8) ] self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle) @combinations.generate(test_base.default_test_combinations()) def testDirectFilenameTFRecordReaderPipeline(self): dataset = core_readers.TFRecordDataset(self.test_filenames) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in (0, 5) for r in range(0, 10) ] self.assertDatasetProduces(dataset, expected) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(shuffle=[True, False]))) def testValidPipelineWithRangeDataset(self, shuffle): dataset = dataset_ops.Dataset.range(self._num_files) dataset = dataset.map(lambda n: string_ops.string_join( # pylint:disable=g-long-lambda [self.get_temp_dir(), string_ops.string_format("/tf_record.{}.txt", [n])])) dataset = dataset.apply( 
interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000)) dataset = dataset.batch(5) dataset = distribute._AutoShardDataset(dataset, 5, 3) expected = [ b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for r in range(0, 10) for f in (3, 8) ] self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5), (10, 1, 1, 10)]))) def testStandardReaderPipeline(self, params): num_epochs, index, batch_size, parallel_reads = params dataset = readers.make_tf_record_dataset( file_pattern=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, parser_fn=None, num_parallel_reads=parallel_reads, drop_final_batch=True, shuffle=False) dataset = distribute._AutoShardDataset(dataset, 2, index) outputs = self.getNext(dataset) self._verify_records( outputs, batch_size=batch_size, file_index=[i for i in range(index, self._num_records, 2)], num_epochs=num_epochs, interleave_cycle_length=parallel_reads, drop_final_batch=True, use_parser_fn=None) with self.assertRaises(errors.OutOfRangeError): self.evaluate(outputs()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(shuffle=[True, False]))) def testSampleResNetPipeline(self, shuffle): dataset = dataset_ops.Dataset.list_files( self.test_filenames, shuffle=shuffle) dataset = dataset.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset = dataset.batch(5) dataset = distribute._AutoShardDataset(dataset, 5, 3) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for r in range(0, 10) for f in (3, 8) ] self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ distribute_options.AutoShardPolicy.DATA, distribute_options.AutoShardPolicy.AUTO ]))) def testShardByDataBeforePrefetch(self, sharding_policy): dataset = dataset_ops.Dataset.range(4) dataset = dataset.apply(testing.assert_next(["Shard", "Prefetch"])) dataset = dataset.prefetch(1) options = dataset_ops.Options() options.experimental_distribute.auto_shard_policy = sharding_policy dataset = dataset.with_options(options) dataset = distribute._AutoShardDataset(dataset, 2, 0) self.assertDatasetProduces(dataset, [0, 2]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.times(combinations.combine( sharding_policy=[distribute_options.AutoShardPolicy.DATA, distribute_options.AutoShardPolicy.FILE]), combinations.combine(shuffle=[True, False])))) def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy): dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=shuffle) dataset = dataset.flat_map(core_readers.TFRecordDataset) graph_def = dataset._as_serialized_graph( strip_device_assignment=True, external_state_policy=distribute_options.ExternalStatePolicy.WARN) options = dataset_ops.Options() options.experimental_distribute.auto_shard_policy = sharding_policy ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0", dataset.element_spec) ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0", dataset.element_spec) ds1 = ds1.with_options(options) ds2 = ds2.with_options(options) ds1 = distribute._AutoShardDataset(ds1, 2, 0) ds2 = 
distribute._AutoShardDataset(ds2, 2, 1) elems1 = set(self.getAllDatasetElements(ds1)) elems2 = set(self.getAllDatasetElements(ds2)) self.assertEmpty(elems1.intersection(elems2)) @combinations.generate(test_base.default_test_combinations()) def testWorkersGreaterThanNumFilesWithDataSharding(self): options = dataset_ops.Options() options.experimental_distribute.auto_shard_policy = ( distribute_options.AutoShardPolicy.DATA) dataset = core_readers._TFRecordDataset(self.test_filenames) dataset = dataset.with_options(options) dataset = distribute._AutoShardDataset(dataset, 5, 0) # Should return "Record (0,5) of file (0 --> 9)" since we are sharding by # individual elements, we should be able to get some data from all files. expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in (0, 5) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testAutoshardPolicyOff(self): options = dataset_ops.Options() options.experimental_distribute.auto_shard_policy = ( distribute_options.AutoShardPolicy.OFF) dataset = core_readers._TFRecordDataset(self.test_filenames) dataset = dataset.with_options(options) dataset = distribute._AutoShardDataset(dataset, 5, 0) # Should return every record in every file since autosharding is turned off. expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testFileShardingWithoutReaderDatasetOp(self): options = dataset_ops.Options() options.experimental_distribute.auto_shard_policy = ( distribute_options.AutoShardPolicy.FILE) dataset = dataset_ops.Dataset.range(1024) dataset = dataset.with_options(options) # We are specifying that we want a file sharding policy, and this pipeline # doesn't start with file reading, so we should error out. with self.assertRaises(errors.NotFoundError): dataset = distribute._AutoShardDataset(dataset, 10, 0) self.evaluate(self.getNext(dataset)()) @combinations.generate(test_base.default_test_combinations()) def testWorkersGreaterThanNumFiles(self): dataset = dataset_ops.Dataset.list_files(self.test_filenames) dataset = dataset.apply( interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10)) dataset = dataset.batch(5) dataset = distribute._AutoShardDataset(dataset, 500, 499) self.assertDatasetProduces(dataset, []) @combinations.generate(test_base.default_test_combinations()) def testTFRecordReaderWithDirectFileNames(self): # Using `_TFRecordDataset` creates a raw op rather than wrapping it around # a flat_map automatically. dataset = core_readers._TFRecordDataset(self.test_filenames) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in (0, 5) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testTFRecordReaderWithDirectFileNamesAndShapes(self): # Using `_TFRecordDataset` creates a raw op rather than wrapping it around # a flat_map automatically. 
dataset = core_readers._TFRecordDataset(self.test_filenames) # BatchDataset contains `output_types` and `output_shapes` dataset = dataset.batch(5) dataset = distribute._AutoShardDataset(dataset, 2, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 5) ] self.assertDatasetProduces(dataset, list(chunk(expected, 5))) @combinations.generate(test_base.default_test_combinations()) def testShardOutOfRange(self): dataset = dataset_ops.Dataset.range(5) with self.assertRaises(errors.InvalidArgumentError): dataset = distribute._AutoShardDataset(dataset, 10, 0) self.evaluate(self.getNext(dataset)()) @combinations.generate(test_base.default_test_combinations()) def testShardOutOfRangeEmptyDataset(self): dataset = dataset_ops.Dataset.range(0) with self.assertRaises(errors.OutOfRangeError): dataset = distribute._AutoShardDataset(dataset, 10, 0) self.evaluate(self.getNext(dataset)()) @combinations.generate(test_base.default_test_combinations()) def testNoReaderPipelines(self): dataset = dataset_ops.Dataset.range(1024) dataset = distribute._AutoShardDataset(dataset, 2, 0) self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0]) @combinations.generate(test_base.default_test_combinations()) def testUnknownOpInPipelineStillShardsAtTheEnd(self): dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) dataset = dataset.flat_map(core_readers.TFRecordDataset) dataset = dataset.apply(unique.unique()) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in (0, 5) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testInvalidWorkerIndex(self): dataset = dataset_ops.Dataset.list_files(self.test_filenames) dataset = dataset.flat_map(core_readers.TFRecordDataset) dataset = dataset.batch(5) with self.assertRaises(errors.InvalidArgumentError): dataset = distribute._AutoShardDataset(dataset, 2, 2) self.evaluate(self.getNext(dataset)()) @combinations.generate(test_base.default_test_combinations()) def testAssertCardinality(self): dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) dataset = dataset.flat_map(core_readers.TFRecordDataset) dataset = dataset.batch(5) dataset = dataset.apply(cardinality.assert_cardinality(42)) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in (0, 5) for r in range(0, 10) ] self.assertDatasetProduces(dataset, list(chunk(expected, 5))) @combinations.generate(test_base.default_test_combinations()) def testMaxIntraOpParallelism(self): dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) dataset = dataset.flat_map(core_readers.TFRecordDataset) dataset = dataset.batch(5) dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in (0, 5) for r in range(0, 10) ] self.assertDatasetProduces(dataset, list(chunk(expected, 5))) @combinations.generate(test_base.default_test_combinations()) def testPrivateThreadpool(self): dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False) dataset = dataset.flat_map(core_readers.TFRecordDataset) dataset = dataset.batch(5) dataset = 
dataset_ops._PrivateThreadPoolDataset(dataset, 1) dataset = distribute._AutoShardDataset(dataset, 5, 0) expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in (0, 5) for r in range(0, 10) ] self.assertDatasetProduces(dataset, list(chunk(expected, 5))) @combinations.generate(test_base.default_test_combinations()) def testMakeBatchedFeaturesDataset(self): files = 2 records_per_file = 5 def make_record(file_index): example = example_pb2.Example( features=feature_pb2.Features( feature={ "file": feature_pb2.Feature( int64_list=feature_pb2.Int64List(value=[file_index])), })) return example.SerializeToString() filenames = [] for file_index in range(files): filename = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % file_index) filenames.append(filename) writer = python_io.TFRecordWriter(filename) for _ in range(records_per_file): writer.write(make_record(file_index)) writer.close() dataset = readers.make_batched_features_dataset( file_pattern=filenames, batch_size=records_per_file, features={ "file": parsing_ops.FixedLenFeature([], dtypes.int64), }, reader=core_readers.TFRecordDataset, num_epochs=1) # We should shard at the file level, so that all records come from file 0. dataset = distribute._AutoShardDataset(dataset, 2, 0) dataset = dataset.unbatch() output = self.getDatasetOutput(dataset) files = [elem["file"] for elem in output] self.assertEqual(files, [0] * records_per_file)
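# --- Illustrative sketch, not part of the original test file ---
# distribute._AutoShardDataset is an internal op wrapper; in user code the
# sharding policy is normally chosen through tf.data.Options and applied by
# tf.distribute when the dataset is distributed. DATA sharding keeps every
# num_workers-th element, FILE sharding splits whole input files (assumes
# TF 2.x; the file list argument is hypothetical).
def _example_auto_shard_policy(filenames=("a.tfrecord", "b.tfrecord")):
  import tensorflow as tf

  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA)

  dataset = tf.data.TFRecordDataset(list(filenames))
  dataset = dataset.with_options(options)
  # When passed to strategy.experimental_distribute_dataset, each worker
  # reads all files but keeps only its own subset of the records.
  return dataset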
class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( input_values=[[4, 5, 6]], cycle_length=1, block_length=1, expected_elements=[[ 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6 ]]) + combinations.combine( input_values=[[4, 5, 6]], cycle_length=2, block_length=1, expected_elements=[[ 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5, 6, 6 ]]) + combinations.combine( input_values=[[4, 5, 6]], cycle_length=2, block_length=3, expected_elements=[[ 4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6 ]]) + combinations.combine( input_values=[[4, 5, 6]], cycle_length=7, block_length=2, expected_elements=[[ 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6 ]]) + combinations.combine(input_values=[[4, 0, 6]], cycle_length=2, block_length=1, expected_elements=[[ 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6 ]]))) def testPythonImplementation(self, input_values, cycle_length, block_length, expected_elements): input_lists = _repeat(input_values, 2) for expected, produced in zip( expected_elements, _interleave(input_lists, cycle_length, block_length)): self.assertEqual(expected, produced) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(input_values=[np.int64([4, 5, 6])], cycle_length=1, block_length=3, num_parallel_calls=[None, 1]) + combinations.combine(input_values=[np.int64([4, 5, 6])], cycle_length=2, block_length=[1, 3], num_parallel_calls=[None, 1, 2]) + combinations.combine(input_values=[np.int64([4, 5, 6])], cycle_length=7, block_length=2, num_parallel_calls=[None, 1, 3, 5, 7]) + combinations.combine(input_values=[np.int64([4, 5, 6, 7])], cycle_length=dataset_ops.AUTOTUNE, block_length=3, num_parallel_calls=[None, 1]) + combinations.combine( input_values=[np.int64([]), np.int64([0, 0, 0])], cycle_length=2, block_length=3, num_parallel_calls=[None]) + combinations.combine(input_values=[np.int64([4, 0, 6])], cycle_length=2, block_length=3, num_parallel_calls=[None, 1, 2]))) def testInterleaveDataset(self, input_values, cycle_length, block_length, num_parallel_calls): count = 2 dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat( count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), cycle_length, block_length, num_parallel_calls) expected_output = [ element for element in _interleave(_repeat(input_values, count), cycle_length, block_length) ] self.assertDatasetProduces(dataset, expected_output) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( input_values=[np.float32([1., np.nan, 2., np.nan, 3.])], cycle_length=1, block_length=3, num_parallel_calls=[None, 1]) + combinations.combine( input_values=[np.float32([1., np.nan, 2., np.nan, 3.])], cycle_length=2, block_length=[1, 3], num_parallel_calls=[None, 1, 2]) + combinations.combine( input_values=[np.float32([1., np.nan, 2., np.nan, 3.])], cycle_length=7, block_length=2, num_parallel_calls=[None, 1, 3, 5, 7]))) def testInterleaveDatasetError(self, input_values, cycle_length, block_length, num_parallel_calls): dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( lambda x: array_ops.check_numerics(x, "message")).interleave( dataset_ops.Dataset.from_tensors, cycle_length, block_length, num_parallel_calls) 
get_next = self.getNext(dataset) for value in input_values: if np.isnan(value): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(get_next()) else: self.assertEqual(value, self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testInterleaveSparse(self): def _map_fn(i): return sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) def _interleave_fn(x): return dataset_ops.Dataset.from_tensor_slices( sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) dataset = dataset_ops.Dataset.range(10).map(_map_fn).interleave( _interleave_fn, cycle_length=1) get_next = self.getNext(dataset) for i in range(10): for j in range(2): expected = [i, 0] if j % 2 == 0 else [0, -i] self.assertAllEqual(expected, self.evaluate(get_next())) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(input_values=[np.int64([4, 5, 6])], cycle_length=1, block_length=3, num_parallel_calls=1) + combinations.combine(input_values=[np.int64([4, 5, 6])], cycle_length=2, block_length=[1, 3], num_parallel_calls=[1, 2]) + combinations.combine(input_values=[np.int64([4, 5, 6])], cycle_length=7, block_length=2, num_parallel_calls=[1, 3, 5, 7]) + combinations.combine(input_values=[np.int64([4, 5, 6, 7])], cycle_length=dataset_ops.AUTOTUNE, block_length=3, num_parallel_calls=1) + combinations.combine(input_values=[np.int64([4, 0, 6])], cycle_length=2, block_length=3, num_parallel_calls=[1, 2]))) def testSloppyInterleaveDataset(self, input_values, cycle_length, block_length, num_parallel_calls): count = 2 dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat( count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), cycle_length, block_length, num_parallel_calls) options = dataset_ops.Options() options.experimental_deterministic = False dataset = dataset.with_options(options) expected_output = [ element for element in _interleave(_repeat(input_values, count), cycle_length, block_length) ] get_next = self.getNext(dataset) actual_output = [] for _ in range(len(expected_output)): actual_output.append(self.evaluate(get_next())) self.assertAllEqual(expected_output.sort(), actual_output.sort()) @combinations.generate(test_base.default_test_combinations()) def testInterleaveMap(self): dataset = dataset_ops.Dataset.range(100) def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) return dataset.map(lambda x: x + x) dataset = dataset.interleave(interleave_fn, cycle_length=5) dataset = dataset.interleave(interleave_fn, cycle_length=5) self.assertDatasetProduces(dataset, [4 * x for x in range(100)]) @combinations.generate(test_base.default_test_combinations()) def testParallelInterleaveCached(self): dataset = dataset_ops.Dataset.range(5) dataset = dataset.cache(os.path.join(self.get_temp_dir(), "cache_dir")) def interleave_fn(x): return dataset_ops.Dataset.from_tensors(x) dataset = dataset.interleave(interleave_fn, cycle_length=2, num_parallel_calls=2) self.assertDatasetProduces(dataset, list(range(5))) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(local_determinism=[None, True, False], global_determinism=[True, False]))) def testDeterminismConfiguration(self, 
                                    local_determinism, global_determinism):
    expect_determinism = local_determinism or (local_determinism is None and
                                               global_determinism)
    elements = list(range(1000))

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
      dataset = dataset.interleave(
          interleave_fn,
          cycle_length=10,
          num_parallel_calls=10,
          deterministic=local_determinism)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = global_determinism
      dataset = dataset.with_options(opts)
      return dataset

    self.checkDeterminism(dataset_fn, expect_determinism, elements)
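# The `_repeat` and `_interleave` helpers referenced by the tests above are
# presumably defined earlier in this file. As an illustrative sketch only
# (the names below are hypothetical, not the actual helpers), a pure-Python
# reference for the cycle/block round-robin semantics that
# Dataset.interleave follows could look like this: `cycle_length` input
# iterators are kept open in fixed slots, up to `block_length` elements are
# drawn from each slot per pass, and an exhausted slot is refilled from the
# remaining inputs in order.
def _reference_repeat(values, count):
  # Flatten `count` repetitions of `values` into a single list.
  return [v for _ in range(count) for v in values]


def _reference_interleave(lists, cycle_length, block_length):
  # Queue of iterators over each input list, in input order.
  all_iterators = [iter(l) for l in lists]
  # Fixed slots for the iterators currently being interleaved.
  open_iterators = []
  num_open = 0
  for _ in range(cycle_length):
    if all_iterators:
      open_iterators.append(all_iterators.pop(0))
      num_open += 1
    else:
      open_iterators.append(None)

  while num_open or all_iterators:
    for i in range(cycle_length):
      if open_iterators[i] is None:
        if all_iterators:
          open_iterators[i] = all_iterators.pop(0)
          num_open += 1
        else:
          continue
      # Draw up to `block_length` elements from this slot before moving on.
      for _ in range(block_length):
        try:
          yield next(open_iterators[i])
        except StopIteration:
          open_iterators[i] = None
          num_open -= 1
          break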
class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( count=20, size=[10, 14, 17], shift=[7, 14], stride=[1, 2, 6], drop_remainder=[True, False]) + combinations.combine( count=[0, 1], size=10, shift=4, stride=1, drop_remainder=[True, False]))) def testWindowDataset(self, count, size, shift, stride, drop_remainder): """Tests a dataset that slides a window its input elements.""" components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) def _flat_map_fn(x, y, z): return dataset_ops.Dataset.zip((x.batch(batch_size=size), y.batch(batch_size=size), z.batch(batch_size=size))) dataset = dataset_ops.Dataset.from_tensor_slices(components).map( _map_fn).repeat(count).window( size=size, shift=shift, stride=stride, drop_remainder=drop_remainder).flat_map(_flat_map_fn) get_next = self.getNext(dataset) self.assertEqual([[None] + list(c.shape[1:]) for c in components], [ts.as_list() for ts in nest.flatten( dataset_ops.get_legacy_output_shapes(dataset))]) num_full_batches = max(0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1) for i in range(num_full_batches): result = self.evaluate(get_next()) for component, result_component in zip(components, result): for j in range(size): self.assertAllEqual(component[(i * shift + j * stride) % 7]**2, result_component[j]) if not drop_remainder: num_partial_batches = (count * 7) // shift + ( (count * 7) % shift > 0) - num_full_batches for i in range(num_partial_batches): result = self.evaluate(get_next()) for component, result_component in zip(components, result): remaining = (count * 7) - ((num_full_batches + i) * shift) num_elements = remaining // stride + ((remaining % stride) > 0) for j in range(num_elements): self.assertAllEqual( component[((num_full_batches + i) * shift + j * stride) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(count=20, size=0, shift=3, stride=1) + combinations.combine(count=20, size=3, shift=0, stride=1) + combinations.combine(count=20, size=3, shift=3, stride=0))) def testWindowDatasetInvalid(self, count, size, shift, stride): with self.assertRaises(errors.InvalidArgumentError): ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window( size=size, shift=shift, stride=stride).flat_map(lambda x: x.batch(batch_size=size)) self.evaluate(ds._variant_tensor) @combinations.generate(test_base.default_test_combinations()) def testWindowDifferentNestedStructures(self): ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2) self.getNext(ds) ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2) self.getNext(ds) @combinations.generate(test_base.default_test_combinations()) def testWindowSparse(self): def _sparse(i): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) dataset = dataset_ops.Dataset.range(10).map(_sparse).window( size=5, shift=3, drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5)) num_batches = (10 - 5) // 3 + 1 expected_output = [ sparse_tensor.SparseTensorValue( indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], values=[i * 3, i * 3 + 1, i * 3 + 2, i * 
3 + 3, i * 3 + 4], dense_shape=[5, 1]) for i in range(num_batches) ] self.assertDatasetProduces(dataset, expected_output=expected_output) @combinations.generate(test_base.default_test_combinations()) def testWindowSparseWithDifferentDenseShapes(self): def _sparse(i): return sparse_tensor.SparseTensorValue( indices=array_ops.expand_dims( math_ops.range(i, dtype=dtypes.int64), 1), values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i), dense_shape=[i]) dataset = dataset_ops.Dataset.range(10).map(_sparse).window( size=5, shift=3, drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5)) expected_output = [] num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): expected_indices = [] expected_values = [] for j in range(5): for k in range(i * 3 + j): expected_indices.append([j, k]) expected_values.append(i * 3 + j) expected_output.append( sparse_tensor.SparseTensorValue( indices=expected_indices, values=expected_values, dense_shape=[5, i * 3 + 5 - 1])) self.assertDatasetProduces(dataset, expected_output=expected_output) @combinations.generate(test_base.default_test_combinations()) def testNestedWindowSparse(self): def _sparse(i): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) dataset = dataset_ops.Dataset.range(10).map(_sparse).window( size=4, shift=2, drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window( size=3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(batch_size=3)) expected_output = [ sparse_tensor.SparseTensorValue( indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], dense_shape=[3, 4, 1]), sparse_tensor.SparseTensorValue( indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], dense_shape=[3, 4, 1]) ] self.assertDatasetProduces(dataset, expected_output=expected_output) @combinations.generate(test_base.default_test_combinations()) def testWindowShapeError(self): def generator(): yield [1.0, 2.0, 3.0] yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] dataset = dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).window( size=3, shift=1).flat_map(lambda x: x.batch(batch_size=3)) self.assertDatasetProduces( dataset, expected_error=( errors.InvalidArgumentError, r"Cannot batch tensors with different shapes in component 0. " r"First element had shape \[3\] and element 2 had shape \[4\].")) @combinations.generate(test_base.default_test_combinations()) def testWindowIgnoreErrors(self): input_values = np.float32([1., np.nan, 2., np.nan, 3.]) dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( lambda x: array_ops.check_numerics(x, "message")).window( size=2, shift=2, stride=2, drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2)) self.assertDatasetProduces( dataset, expected_output=[np.float32([1., 2.]), np.float32([2., 3.])]) @combinations.generate(test_base.default_test_combinations()) def testNestedOutput(self): if not context.executing_eagerly(): self.skipTest("self.evaluate() does not work with a dataset") dataset = dataset_ops.Dataset.range(100) dataset = dataset_ops.Dataset.zip((dataset, dataset)).window(10) for i, nested_dataset in enumerate(dataset): x, y = nested_dataset self.assertDatasetProduces(x, range(i*10, (i+1)*10)) self.assertDatasetProduces(y, range(i*10, (i+1)*10))
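# The window-count arithmetic asserted in testWindowDataset above can be
# summarized in plain Python. This is an illustrative sketch only (the helper
# names below are hypothetical and not used by the tests): one window covers
# (size - 1) * stride + 1 consecutive input positions, a new window starts
# every `shift` positions, and with drop_remainder=False trailing partial
# windows are still produced.
def _full_window_count(num_elements, size, shift, stride):
  # Number of complete windows over `num_elements` input elements.
  span = (size - 1) * stride + 1
  return max(0, (num_elements - span) // shift + 1)


def _window_starts(num_elements, shift):
  # Every window (full or partial) starts at a multiple of `shift`.
  return list(range(0, num_elements, shift))


# For example, 20 repetitions of a 7-element input with size=10, shift=7,
# stride=2 gives _full_window_count(140, 10, 7, 2) == 18 full windows, and
# len(_window_starts(140, 7)) == 20 windows in total when partials are kept.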
class LegacySnapshotDatasetTest( reader_dataset_ops_test_base.TFRecordDatasetTestBase, parameterized.TestCase): def setUp(self): super(LegacySnapshotDatasetTest, self).setUp() self.removeTFRecords() tmpdir = self.get_temp_dir() tmpdir = os.path.join(tmpdir, "snapshot") os.mkdir(tmpdir) self.snapshot_dir = tmpdir def tearDown(self): super(LegacySnapshotDatasetTest, self).tearDown() shutil.rmtree(self.snapshot_dir) def removeTFRecords(self): for filename in self.test_filenames: os.remove(filename) self.test_filenames = [] def setUpTFRecord(self, num_files=10, num_records=10): self._num_files = num_files self._num_records = num_records self.test_filenames = self._createFiles() def makeSnapshotDirectory(self): return self.snapshot_dir def assertSnapshotDirectoryContains(self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files): dirlist_raw = os.listdir(directory) dirlist = [] # Ignore the graphdef pbtxts we write for debugging purposes. for i in range(len(dirlist_raw)): if not dirlist_raw[i].endswith("-graph.pbtxt"): dirlist.append(dirlist_raw[i]) self.assertLen(dirlist, num_fingerprints) for i in range(num_fingerprints): fingerprint_dir = os.path.join(directory, dirlist[i]) fingerprint_dir_list = sorted(os.listdir(fingerprint_dir)) self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1) self.assertEqual(fingerprint_dir_list[num_runs_per_fp], "snapshot.metadata") for j in range(num_runs_per_fp): run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j]) run_dirlist = sorted(os.listdir(run_dir)) self.assertLen(run_dirlist, num_snapshot_files) file_counter = 0 for filename in run_dirlist: self.assertEqual(filename, "%08d.snapshot" % file_counter) file_counter += 1 @combinations.generate(test_base.default_test_combinations()) def testWriteDifferentPipelinesInOneDirectory(self): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(1000))) dataset = dataset_ops.Dataset.range(1001) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(1001))) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotMultipleSimultaneous(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir)) next1 = self.getNext(dataset1) dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) next2 = self.getNext(dataset2) for i in range(0, 1000): self.assertEqual(i, self.evaluate(next1())) self.assertEqual(i, self.evaluate(next2())) # we check that only one copy of the metadata has been written, and the # one that lost the race would be in passthrough mode. self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testGetNextCreatesDir(self): tmpdir = self.snapshot_dir # We create two iterators but call getNext on only one. dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir)) next1 = self.getNext(dataset1) dataset2 = dataset_ops.Dataset.range(1001) dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) _ = self.getNext(dataset2) for _ in range(1000): self.evaluate(next1()) # We check that only one directory is created. 
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testWriteSnapshotSimpleSuccessful(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset, list(range(1000))) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testWriteSnapshotRepeatAfterwards(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, compression=compression)) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testWriteSnapshotMixTypes(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) def map_fn(x): return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x) dataset = dataset.map(map_fn) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, compression=compression)) dataset = dataset.repeat(10) expected = [] for i in range(10): expected.append((i, str(i), str(2 * i), 2 * i)) self.assertDatasetProduces(dataset, expected * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testSpecifySnapshotNameWriteAndRead(self): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, snapshot_name="my_custom_snapshot")) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) self.assertTrue( os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot"))) self.assertTrue( os.path.exists( os.path.join(tmpdir, "custom-my_custom_snapshot", "custom"))) @combinations.generate(test_base.default_test_combinations()) def testForcePassthroughMode(self): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, mode="passthrough")) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0) @combinations.generate(test_base.default_test_combinations()) def testForceWriteMode(self): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="write")) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) # We will end up writing 10 different runs. self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1) @combinations.generate(test_base.default_test_combinations()) def testForceReadMode(self): tmpdir = self.snapshot_dir # We write a copy of the snapshot first. 
dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, mode="write", snapshot_name="my_custom_snapshot")) self.assertDatasetProduces(dataset, list(range(10))) # We move the run to a new name. shutil.move(os.path.join(tmpdir, "custom-my_custom_snapshot"), os.path.join(tmpdir, "custom-my_custom_snapshot_2")) # Even though the snapshot.metadata is pointing to the old run that no # longer exists after we moved, we force it to read from the run we specify. dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, mode="read", snapshot_name="my_custom_snapshot_2")) self.assertDatasetProduces(dataset, list(range(10))) # We should still have one snapshot and one run. self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testForceReadNonexistentSnapshot(self): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) with self.assertRaises(errors.NotFoundError): dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, mode="read")) get_next = self.getNext(dataset) self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testForceReadNonexistentNamedSnapshot(self): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) with self.assertRaises(errors.NotFoundError): dataset = dataset.apply( snapshot.legacy_snapshot( tmpdir, mode="read", snapshot_name="my_nonexistent_snapshot")) get_next = self.getNext(dataset) self.evaluate(get_next()) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testReadSnapshotBackAfterWrite(self, compression): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset2, expected) @combinations.generate(test_base.default_test_combinations()) def testReadShuffledSnapshotAfterWrite(self): self.setUpTFRecord(num_files=10, num_records=50) filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 50) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100, shuffle_on_read=True)) next2 = self.getNext(dataset2) res1 = self.evaluate(next2()) res2 = self.evaluate(next2()) res3 = self.evaluate(next2()) res4 = self.evaluate(next2()) res5 = self.evaluate(next2()) # make sure that we don't read the file back in the same order. 
self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5]) # make sure all the elements are still there dataset3 = core_readers._TFRecordDataset(filenames) dataset3 = dataset3.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100, shuffle_on_read=True)) self.assertDatasetProduces(dataset3, expected, assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testReadShuffledSnapshotWithSeedAfterWrite(self): self.setUpTFRecord(num_files=10, num_records=50) filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 50) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10, shuffle_on_read=True, shuffle_seed=123456)) next2 = self.getNext(dataset2) dataset3 = core_readers._TFRecordDataset(filenames) dataset3 = dataset3.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10, shuffle_on_read=True, shuffle_seed=123456)) next3 = self.getNext(dataset3) # make sure that the items are read back in the same order for both datasets for _ in range(500): res2 = self.evaluate(next2()) res3 = self.evaluate(next3()) self.assertEqual(res2, res3) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testReadSnapshotParallelAfterWrite(self, compression): self.setUpTFRecord(10, 4000) filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 4000) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, reader_buffer_size=10, compression=compression)) self.assertDatasetProduces(dataset, expected, assert_items_equal=True) # remove the original files and try to read the data back only from # snapshot. self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, reader_buffer_size=10, compression=compression)) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) # Not testing Snappy here because Snappy reads currently require a lot of # memory. 
@combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.times( combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP ]), combinations.combine(threads=2, size=[1, 2]) + combinations.combine(threads=8, size=[1, 4, 8])))) def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads, size): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, compression=compression, num_writer_threads=threads, writer_buffer_size=size)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from # snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testSameFingerprintWithDifferentInitializationOrder(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(0, 100) dataset2 = dataset_ops.Dataset.range(100, 200) dataset3 = dataset_ops.Dataset.range(200, 300) dataset = dataset1.concatenate(dataset2).concatenate(dataset3) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(300))) dataset4 = dataset_ops.Dataset.range(200, 300) dataset5 = dataset_ops.Dataset.range(100, 200) dataset6 = dataset_ops.Dataset.range(0, 100) dataset = dataset6.concatenate(dataset5).concatenate(dataset4) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(300))) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testExpiredSnapshotRewrite(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply( snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) next1 = self.getNext(dataset1) # Don't finish reading dataset1, so it is never finalized for _ in range(500): self.evaluate(next1()) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) time.sleep(2) # Creating dataset2 after we run through dataset1 due to eager mode, where # the snapshot state is determined immediately upon dataset creation. We # only want to determine the snapshot state for dataset2 after the first # snapshot has expired. 
dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply( snapshot.legacy_snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) next2 = self.getNext(dataset2) for _ in range(500): self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1) @combinations.generate(test_base.default_test_combinations()) def testSnapshotArgsCreateNewSnapshot(self): tmpdir = self.snapshot_dir dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10000)) next1 = self.getNext(dataset1) for _ in range(1000): self.evaluate(next1()) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) # Create second snapshot with a different shard_size_bytes dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset1.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=20000)) next2 = self.getNext(dataset2) for _ in range(1000): self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, snapshot.COMPRESSION_SNAPPY ]))) def testSpecifyShardSize(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset.map( lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.repeat(10) dataset = dataset.apply( snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression)) next_fn = self.getNext(dataset) for _ in range(10): self.evaluate(next_fn()) num_files = 1 if compression == snapshot.COMPRESSION_NONE: num_files = 3 self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files) @combinations.generate(test_base.default_test_combinations()) def testAdditionalOperationsAfterReadBack(self): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.snapshot_dir dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir)) self.assertDatasetProduces(dataset2, expected) expected_after = [ b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] dataset3 = core_readers._TFRecordDataset(filenames) dataset3 = dataset3.apply(snapshot.legacy_snapshot(tmpdir)) dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000)) self.assertDatasetProduces(dataset3, expected_after)
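# As implied by assertSnapshotDirectoryContains above, the legacy snapshot
# tests expect an on-disk layout of one directory per pipeline fingerprint,
# each containing a "snapshot.metadata" file plus one directory per run, and
# each run directory holding "%08d.snapshot" shard files. A small
# illustrative helper (the name is hypothetical, not used by the tests) for
# the expected shard file names:
def _expected_snapshot_shard_names(num_shards):
  # e.g. _expected_snapshot_shard_names(2) == ["00000000.snapshot",
  #                                            "00000001.snapshot"]
  return ["%08d.snapshot" % i for i in range(num_shards)]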
class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testOptionsDefault(self): ds = dataset_ops.Dataset.range(0) self.assertEqual(dataset_ops.Options(), ds.options()) @combinations.generate(test_base.default_test_combinations()) def testOptionsOnce(self): options = dataset_ops.Options() ds = dataset_ops.Dataset.range(0).with_options(options).cache() self.assertEqual(options, ds.options()) @combinations.generate(test_base.default_test_combinations()) def testOptionsTwiceSame(self): options = dataset_ops.Options() options.experimental_optimization.autotune = True ds = dataset_ops.Dataset.range(0).with_options(options).with_options( options) self.assertEqual(options, ds.options()) @combinations.generate(test_base.default_test_combinations()) def testOptionsTwiceDifferent(self): options1 = dataset_ops.Options() options1.experimental_optimization.autotune = True options2 = dataset_ops.Options() options2.experimental_deterministic = False ds = dataset_ops.Dataset.range(0).with_options(options1).with_options( options2) self.assertTrue(ds.options().experimental_optimization.autotune) # Explicitly check that flag is False since assertFalse allows None self.assertIs(ds.options().experimental_deterministic, False) @combinations.generate(test_base.default_test_combinations()) def testOptionsTwiceDifferentError(self): options1 = dataset_ops.Options() options1.experimental_optimization.autotune = True options2 = dataset_ops.Options() options2.experimental_optimization.autotune = False with self.assertRaisesRegexp(ValueError, "Cannot merge incompatible values"): dataset_ops.Dataset.range(0).with_options(options1).with_options( options2) @combinations.generate(test_base.default_test_combinations()) def testOptionsMergeOptionsFromMultipleInputs(self): options1 = dataset_ops.Options() options1.experimental_optimization.autotune = True options2 = dataset_ops.Options() options2.experimental_deterministic = True ds = dataset_ops.Dataset.zip( (dataset_ops.Dataset.range(0).with_options(options1), dataset_ops.Dataset.range(0).with_options(options2))) self.assertTrue(ds.options().experimental_optimization.autotune) self.assertTrue(ds.options().experimental_deterministic) @combinations.generate(test_base.default_test_combinations()) def testOptionsHaveDefaults(self): options1 = dataset_ops.Options() options2 = dataset_ops.Options() self.assertIsNot(options1.experimental_optimization, options2.experimental_optimization) self.assertIsNot(options1.experimental_stats, options2.experimental_stats) self.assertIsNot(options1.experimental_threading, options2.experimental_threading) self.assertEqual(options1.experimental_optimization, optimization_options.OptimizationOptions()) self.assertEqual(options1.experimental_stats, stats_options.StatsOptions()) self.assertEqual(options1.experimental_threading, threading_options.ThreadingOptions())
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, parameterized.TestCase): def setUp(self): super(SnapshotDatasetTest, self).setUp() tmpdir = self.get_temp_dir() tmpdir = os.path.join(tmpdir, "snapshot") os.mkdir(tmpdir) self._snapshot_dir = tmpdir def tearDown(self): super(SnapshotDatasetTest, self).tearDown() shutil.rmtree(self._snapshot_dir) def createTFRecords(self, num_files=10, num_records=100): self._num_files = num_files self._num_records = num_records self._test_filenames = self._createFiles() def removeTFRecords(self): for filename in self._test_filenames: os.remove(filename) self._test_filenames = [] self._num_files = None self._num_records = None def assertDatasetProducesSet(self, dataset, expected): actual = [] next_fn = self.getNext(dataset) for _ in range(len(expected)): elem = self.evaluate(next_fn()) actual.append(elem) self.assertCountEqual(actual, expected) with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_fn()) def assertSnapshotDirectoryContains(self, directory, num_fingerprints, num_runs_per_fingerprint, num_snapshot_shards_per_run): dirlist_raw = os.listdir(directory) dirlist = [] # Ignore the graphdef pbtxts we write for debugging purposes. for i in range(len(dirlist_raw)): if not dirlist_raw[i].endswith("-graph.pbtxt"): dirlist.append(dirlist_raw[i]) self.assertLen(dirlist, num_fingerprints) for i in range(num_fingerprints): fingerprint_dir = os.path.join(directory, dirlist[i]) fingerprint_dir_list = sorted(os.listdir(fingerprint_dir)) self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1) self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint], "snapshot.metadata") for j in range(num_runs_per_fingerprint): run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j]) run_dirlist = sorted(os.listdir(run_dir)) self.assertLen(run_dirlist, num_snapshot_shards_per_run) file_counter = 0 for filename in run_dirlist: self.assertEqual(filename, "%08d.shard" % file_counter) file_counter += 1 @combinations.generate(test_base.default_test_combinations()) def testCreateSnapshotDataset(self): dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]) dataset.apply(snapshot.snapshot(self._snapshot_dir)) @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotDatasetDefault(self): self.createTFRecords() filenames = self._test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 100) ] dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset, expected) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset2, expected) @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotDatasetAutoWriteSnappyRead(self): self.createTFRecords() filenames = self._test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 100) ] dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.snapshot(self._snapshot_dir, compression="AUTO")) self.assertDatasetProduces(dataset, expected) self.removeTFRecords() dataset2 = 
core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot(self._snapshot_dir, compression="SNAPPY")) self.assertDatasetProduces(dataset2, expected) @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotDatasetCustomShardFn(self): self.createTFRecords() filenames = self._test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 100) ] dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: np.int64(0))) self.assertDatasetProduces(dataset, expected) self.assertSnapshotDirectoryContains(self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=1) self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0)) self.assertDatasetProduces(dataset2, expected) @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotDatasetCustomReaderFn(self): self.createTFRecords() filenames = self._test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 100) ] dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.snapshot( self._snapshot_dir, reader_func=( lambda ds: ds.interleave( # pylint:disable=g-long-lambda lambda x: x, cycle_length=4, num_parallel_calls=4)))) self.assertDatasetProduces(dataset, expected) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot( self._snapshot_dir, reader_func=( lambda ds: ds.interleave( # pylint:disable=g-long-lambda lambda x: x, cycle_length=4, num_parallel_calls=4)))) self.assertDatasetProducesSet(dataset2, expected) @combinations.generate(test_base.default_test_combinations()) def testSnapshotDatasetInvalidShardFn(self): dataset = dataset_ops.Dataset.range(1000) with self.assertRaises(TypeError): dataset = dataset.apply( snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: "invalid_fn")) next_fn = self.getNext(dataset) self.evaluate(next_fn()) @combinations.generate(test_base.default_test_combinations()) def testSnapshotDatasetInvalidReaderFn(self): dataset = dataset_ops.Dataset.range(1000) with self.assertRaises(TypeError): dataset = dataset.apply( snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1)) next_fn = self.getNext(dataset) self.evaluate(next_fn()) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotDatasetSimple(self): dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset, list(range(1000))) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotDatasetMultipleFingerprints(self): dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset1, list(range(1000))) dataset2 = dataset_ops.Dataset.range(2000) dataset2 = 
dataset2.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset2, list(range(2000))) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=2, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self): dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset1, list(range(1000))) dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset2, list(range(1000))) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self): dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir)) next1 = self.getNext(dataset1) for i in range(500): self.assertEqual(i, self.evaluate(next1())) dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir)) next2 = self.getNext(dataset2) for i in range(500): self.assertEqual(i, self.evaluate(next2())) for i in range(500, 1000): self.assertEqual(i, self.evaluate(next1())) self.assertEqual(i, self.evaluate(next2())) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=2, num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotCustomShardFunction(self): dataset = dataset_ops.Dataset.range(1000) dataset = dataset.enumerate() dataset = dataset.apply( snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2)) dataset = dataset.map(lambda _, elem: elem) self.assertDatasetProduces(dataset, list(range(1000))) self.assertSnapshotDirectoryContains(self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=2) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotDatasetWithTuples(self): dataset1 = dataset_ops.Dataset.range(0, 1000) dataset2 = dataset_ops.Dataset.range(1000, 2000) dataset3 = dataset_ops.Dataset.range(2000, 3000) dataset4 = dataset_ops.Dataset.range(3000, 4000) dataset = dataset_ops.Dataset.zip( (dataset1, dataset2, dataset3, dataset4)) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) expected = list( zip(range(0, 1000), range(1000, 2000), range(2000, 3000), range(3000, 4000))) self.assertDatasetProduces(dataset, expected) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotShuffleSameFingerprint(self): def make_dataset(): dataset = dataset_ops.Dataset.range(1000) dataset = dataset.shuffle(1000) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) return dataset dataset1 = make_dataset() self.assertDatasetProducesSet(dataset1, list(range(1000))) dataset2 = make_dataset() self.assertDatasetProducesSet(dataset2, list(range(1000))) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, 
num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testReadUsingFlatMap(self): dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProduces(dataset, list(range(1000))) flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map( lambda x: x) self.assertDatasetProduces(flat_map, list(range(1000))) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) @combinations.generate(test_base.default_test_combinations()) def testReadOptimizableUsingFlatMap(self): if context.context().use_tfrt: self.skipTest("b/177260096: Flaky test.") dataset = dataset_ops.Dataset.range(100) # Will be optimized into ShuffleAndRepeat. dataset = dataset.shuffle(10) dataset = dataset.repeat(2) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) self.assertDatasetProducesSet(dataset, 2 * list(range(100))) flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map( lambda x: x) self.assertDatasetProducesSet(flat_map, 2 * list(range(100))) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count())
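# The non-legacy snapshot tests above expect "%08d.shard" files per run and,
# when no shard_func is supplied, one shard per CPU
# (multiprocessing.cpu_count()). A hedged sketch of a custom shard function
# in the style of testWriteSnapshotCustomShardFunction (the factory below is
# illustrative only): pair each element with its index via Dataset.enumerate()
# and route it round-robin to one of `num_shards` shards.
def _round_robin_shard_func(num_shards):

  def shard_func(index, _):
    # `index` is the int64 position produced by Dataset.enumerate().
    return index % num_shards

  return shard_func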
class DataServiceOpsTest(data_service_test_base.TestBase, parameterized.TestCase): @combinations.generate( combinations.times(test_base.default_test_combinations(), data_service_test_base.all_cluster_configurations()) ) def testDistributeBasic(self, work_dir, fault_tolerant_mode): cluster = data_service_test_base.TestCluster( num_workers=1, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster) self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(compression=[None, "AUTO"]))) def testDistributeCompression(self, compression): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, compression=compression) self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testDistributeInvalidCompression(self): cluster = data_service_test_base.TestCluster(num_workers=1) with self.assertRaisesRegex(ValueError, "Invalid compression argument"): self.make_distributed_range_dataset(10, cluster, compression="foo") @combinations.generate(test_base.eager_only_combinations()) def testDistributeSparse(self): cluster = data_service_test_base.TestCluster(num_workers=1) element = sparse_tensor.SparseTensor(indices=[[0]], values=constant_op.constant( [0], dtype=dtypes.int32), dense_shape=[1]) ds = dataset_ops.Dataset.from_tensors(element) ds = self.make_distributed_dataset(ds, cluster) results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds] self.assertAllEqual(results, [[0]]) @combinations.generate(test_base.eager_only_combinations()) def testDistributeRagged(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8]) ds = ds.map(math_ops.range) ds = ds.apply(batching.dense_to_ragged_batch(2)) ds = self.make_distributed_dataset(ds, cluster) results = [elem.to_tensor() for elem in ds] self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]]) self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]]) self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]]) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( init_source=["textfile", "keyvaluetensor", "dataset"]))) def testDistributeLookupTable(self, init_source): cluster = data_service_test_base.TestCluster(num_workers=1) initializer = self.lookupTableInitializer(init_source, [10, 11]) table = lookup_ops.StaticHashTable(initializer, -1) ds = dataset_ops.Dataset.range(3) ds = ds.map(table.lookup) ds = self.make_distributed_dataset(ds, cluster) self.evaluate(lookup_ops.tables_initializer()) self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True) @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(value_rank=[0, 1]))) def testDistributeMutableHashTable(self, value_rank): def value(v): for _ in range(value_rank): v = [v, v] return v v1 = value(10) v2 = value(11) default_value = value(-1) cluster = data_service_test_base.TestCluster(num_workers=1) table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64, default_value) self.evaluate(table.insert([0, 1], [v1, v2])) ds = dataset_ops.Dataset.range(3) ds = ds.map(table.lookup) ds = self.make_distributed_dataset(ds, cluster) self.assertDatasetProduces(ds, [v1, v2, 
default_value], requires_initialization=True) @combinations.generate(test_base.default_test_combinations()) def testDifferentShuffleOrders(self): random_seed.set_random_seed(None) num_elements = 100 cluster = data_service_test_base.TestCluster(num_workers=2) ds = dataset_ops.Dataset.range(num_elements) ds = ds.shuffle(num_elements) ds = self.make_distributed_dataset(ds, cluster) output = self.getDatasetOutput(ds) # The output will be two sequences of range(num_elements) # non-deterministically interleaved together. If the orders of the elements # were the same, first_order and second_order computed below will be equal. first_order = {} second_order = {} for element in output: if element in first_order: second_order[element] = len(second_order) else: first_order[element] = len(first_order) self.assertNotEqual(first_order, second_order) @combinations.generate(test_base.default_test_combinations()) def testMultipleEpochs(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 3 ds = self.make_distributed_range_dataset(num_elements, cluster) for _ in range(10): self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testRepeatedDataset(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 num_repetitions = 5 ds = self.make_distributed_range_dataset(num_elements, cluster) ds = ds.repeat(num_repetitions) self.assertDatasetProduces(ds, expected_output=num_repetitions * list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testConcurrentEpoch(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 num_datasets = 3 get_nexts = [] results = [] for _ in range(num_datasets): ds = self.make_distributed_range_dataset(num_elements, cluster) get_nexts.append(self.getNext(ds)) results.append([]) for _ in range(num_elements): for dataset_ind in range(num_datasets): result = self.evaluate(get_nexts[dataset_ind]()) results[dataset_ind].append(result) for result in results: self.assertEqual(list(range(num_elements)), result) @combinations.generate(test_base.default_test_combinations()) def testMultiWorker(self): num_workers = 3 cluster = data_service_test_base.TestCluster(num_workers=num_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster) self.assertDatasetProduces(ds, num_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testMaxOutstandingRequests(self): num_workers = 3 cluster = data_service_test_base.TestCluster(num_workers=num_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, max_outstanding_requests=1) self.assertDatasetProduces(ds, num_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.eager_only_combinations()) def testInsideFunction(self): num_workers = 3 cluster = data_service_test_base.TestCluster(num_workers=num_workers) num_elements = 10 @def_function.function def f(): ds = self.make_distributed_range_dataset(num_elements, cluster) result = tensor_array_ops.TensorArray(dtypes.int64, size=num_workers * num_elements, dynamic_size=True) i = 0 for elem in ds: result = result.write(i, elem) i += 1 return result.stack() result = list(f().numpy()) self.assertCountEqual(num_workers * list(range(num_elements)), result) @combinations.generate(test_base.default_test_combinations()) def testSharedJobName(self): cluster = 
data_service_test_base.TestCluster(num_workers=1) num_elements = 1000 def make_ds(): return dataset_ops.Dataset.range(num_elements).shuffle( num_elements) ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name") ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name") get_next_1 = self.getNext(ds1) get_next_2 = self.getNext(ds2) results = [] for _ in range(num_elements // 5): results.append(self.evaluate(get_next_1())) results.append(self.evaluate(get_next_2())) results += self.getIteratorOutput(get_next_1) results += self.getIteratorOutput(get_next_2) self.assertCountEqual(list(range(num_elements)), results) @combinations.generate(test_base.default_test_combinations()) def testDifferentJobNames(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds1 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name1") ds2 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name2") self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, list(range(num_elements))) @combinations.generate(test_base.eager_only_combinations()) def testSharedJobNameMultiIteration(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds1 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") ds2 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") # iteration 1 self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, []) # iteration 2 self.assertDatasetProduces(ds2, list(range(num_elements))) self.assertDatasetProduces(ds1, []) @combinations.generate(test_base.default_test_combinations()) def testSharedJobNameRepeat(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 100 num_repetitions = 3 ds1 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") ds1 = ds1.repeat(num_repetitions) ds2 = self.make_distributed_range_dataset(num_elements, cluster, job_name="job_name") ds2 = ds2.repeat(num_repetitions) results = [] get_next_1 = self.getNext(ds1) get_next_2 = self.getNext(ds2) for _ in range((num_elements * num_repetitions) // 5): results.append(self.evaluate(get_next_1())) for _ in range((num_elements * num_repetitions) // 5): results.append(self.evaluate(get_next_2())) results += self.getIteratorOutput(get_next_1) results += self.getIteratorOutput(get_next_2) self.assertCountEqual(num_repetitions * list(range(num_elements)), results) @combinations.generate( combinations.times(test_base.eager_only_combinations(), combinations.combine(job_name=[None, "test"]))) def testGcUnusedJob(self, job_name): cluster = data_service_test_base.TestCluster( num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20) num_elements = 100 ds = self.make_distributed_range_dataset(num_elements, cluster, job_name=job_name) it = iter(ds) self.assertEqual(next(it).numpy(), 0) self.assertEqual(cluster.workers[0].num_tasks(), 1) del it while cluster.workers[0].num_tasks() > 0: time.sleep(0.1) @combinations.generate(test_base.eager_only_combinations()) def testDontGcUsedJob(self): cluster = data_service_test_base.TestCluster( num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20) num_elements = 10 it1 = iter( self.make_distributed_range_dataset(num_elements, cluster, job_name="test1")) it2 = iter( self.make_distributed_range_dataset(num_elements, cluster, job_name="test2")) it3 = iter( # this iterator keeps the 
task alive. pylint: disable=unused-variable self.make_distributed_range_dataset(num_elements, cluster, job_name="test2")) self.assertEqual(cluster.workers[0].num_tasks(), 2) del it1 del it2 # Check that only the first job is gced. The second job will not be gced # because there is still an outstanding iterator for it. while cluster.workers[0].num_tasks() > 1: time.sleep(0.1) self.assertEqual(cluster.workers[0].num_tasks(), 1) @combinations.generate(test_base.default_test_combinations()) def testApplyDeterminismOption(self): elements = list(range(10)) cluster = data_service_test_base.TestCluster(num_workers=1) def dataset_fn(delay_ms): def interleave_fn(x): ds = dataset_ops.Dataset.from_tensors(x) if math_ops.equal(x, 0): ds = ds.apply(testing.sleep(delay_ms * 1000)) else: ds = ds.apply(testing.sleep(0)) return ds ds = dataset_ops.Dataset.from_tensor_slices(elements) ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10) opts = dataset_ops.Options() opts.experimental_deterministic = False ds = ds.with_options(opts) ds = self.make_distributed_dataset(ds, cluster) return ds self.checkDeterminism(dataset_fn=dataset_fn, expect_determinism=False, expected_elements=elements) def run_stateful(self, external_state_policy): num_elements = 10 ds = dataset_ops.Dataset.range(num_elements).map( lambda _: random_ops.random_uniform(())) options = dataset_ops.Options() options.experimental_external_state_policy = external_state_policy ds = ds.with_options(options) cluster = data_service_test_base.TestCluster(num_workers=3) ds = self.make_distributed_dataset(ds, cluster) self.getDatasetOutput(ds) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(external_state_policy=[ distribute_options.ExternalStatePolicy.IGNORE, distribute_options.ExternalStatePolicy.WARN ]))) def testStatefulNoError(self, external_state_policy): self.run_stateful(external_state_policy) @combinations.generate(test_base.default_test_combinations()) def testStatefulError(self): with self.assertRaises(errors.FailedPreconditionError): self.run_stateful(distribute_options.ExternalStatePolicy.FAIL) @combinations.generate(test_base.default_test_combinations()) def testDistributeFromInterleave(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.range(2) def interleave_fn(_): dataset = dataset_ops.Dataset.range(2) self.make_distributed_dataset(dataset, cluster) return dataset ds = ds.interleave(interleave_fn, cycle_length=2) self.assertDatasetProduces(ds, [0, 0, 1, 1]) @combinations.generate(test_base.default_test_combinations()) def testDistributeNonStringAddresses(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex(ValueError, "service must be a string"): ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service=1)) @combinations.generate(test_base.default_test_combinations()) def testDistributeEmptyAddress(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesWithLiteralMatch(ValueError, "service must not be empty"): ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service="")) @combinations.generate(test_base.default_test_combinations()) def testDistributeExplicitProtocol(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.range(10) ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service="grpc://" + cluster.dispatcher_address())) self.assertDatasetProduces(ds, list(range(10))) 
@combinations.generate(test_base.default_test_combinations()) def testDistributeInvalidProtocol(self): cluster = data_service_test_base.TestCluster(num_workers=1) ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex( errors.NotFoundError, "No credentials factory has been registered for protocol grp"): ds = ds.apply( data_service_ops.distribute(processing_mode="parallel_epochs", service="grp://" + cluster.dispatcher_address())) self.getDatasetOutput(ds) @combinations.generate(test_base.eager_only_combinations()) def testDistributeInvalidProcessingMode(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex(ValueError, "invalid is not a valid processing mode"): ds = ds.apply( data_service_ops.distribute(processing_mode="invalid", service="grpc://localhost:5000")) @combinations.generate(test_base.default_test_combinations()) def testZipDifferentProcessingModesDatasets(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 100 ds1 = dataset_ops.Dataset.range(num_elements) ds1 = self.make_distributed_dataset( ds1, cluster, processing_mode="distributed_epoch") ds2 = dataset_ops.Dataset.range(num_elements) ds2 = self.make_distributed_dataset(ds2, cluster, processing_mode="parallel_epochs") ds = dataset_ops.Dataset.zip((ds1, ds2)) self.assertDatasetProduces(ds, list( zip(range(num_elements), range(num_elements))), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testZipDifferentProcessingModesDatasetsSharedJobName(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 100 ds1 = dataset_ops.Dataset.range(num_elements) ds1 = self.make_distributed_dataset( ds1, cluster, processing_mode="distributed_epoch", job_name="job_name") ds2 = dataset_ops.Dataset.range(num_elements) ds2 = self.make_distributed_dataset(ds2, cluster, processing_mode="parallel_epochs", job_name="job_name") ds = dataset_ops.Dataset.zip((ds1, ds2)) with self.assertRaisesRegex(errors.FailedPreconditionError, "but there is already an existing job"): self.getDatasetOutput(ds) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetId(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) dataset_id = data_service_ops.register_dataset( cluster.dispatcher_address(), ds) from_dataset_id_ds = data_service_ops.from_dataset_id( "parallel_epochs", cluster.dispatcher_address(), dataset_id, ds.element_spec) self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements))) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdMultipleComponents(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds}) dataset_id = data_service_ops.register_dataset( cluster.dispatcher_address(), ds) from_dataset_id_ds = data_service_ops.from_dataset_id( "parallel_epochs", cluster.dispatcher_address(), dataset_id, ds.element_spec) output = self.getDatasetOutput(from_dataset_id_ds) for i in range(num_elements): self.assertEqual(i, output[i]["a"][0]) self.assertEqual(i, output[i]["a"][1]) self.assertEqual(i, output[i]["b"]) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdWrongElementSpec(self): cluster = data_service_test_base.TestCluster(num_workers=1) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) dataset_id = 
data_service_ops.register_dataset( cluster.dispatcher_address(), ds) wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant) from_dataset_id_ds = data_service_ops.from_dataset_id( "parallel_epochs", cluster.dispatcher_address(), dataset_id, wrong_spec) with self.assertRaisesRegex(errors.FailedPreconditionError, "Expected a tensor of type variant"): self.evaluate(self.getNext(from_dataset_id_ds)()) @combinations.generate(test_base.default_test_combinations()) def testFromDatasetIdNotRegistered(self): cluster = data_service_test_base.TestCluster(num_workers=1) dataset_id = 0 element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant) from_dataset_id_ds = data_service_ops.from_dataset_id( "parallel_epochs", cluster.dispatcher_address(), dataset_id, element_spec) with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"): self.evaluate(self.getNext(from_dataset_id_ds)()) @combinations.generate(test_base.default_test_combinations()) def testCancellation(self): self.skipTest("b/162521601") sleep_microseconds = int(1e6) * 1000 cluster = data_service_test_base.TestCluster(num_workers=1) # Create a dataset which produces the first element quickly, and the second # element slowly. Fetching the first element triggers prefetching of the # second element, which we should be able to cancel. slow = dataset_ops.Dataset.range(1) slow = slow.apply(testing.sleep(sleep_microseconds)) ds = dataset_ops.Dataset.range(1).concatenate(slow) ds = self.make_distributed_dataset(ds, cluster) ds = ds.prefetch(1) get_next = self.getNext(ds) self.assertEqual(0, self.evaluate(get_next())) # Without properly implemented cancellation, we will hang here while trying # to garbage collect the dataset iterator. @combinations.generate(test_base.default_test_combinations()) def testRegisterEquivalentDatasets(self): ds_1 = dataset_ops.Dataset.range(10) ds_2 = dataset_ops.Dataset.range(10) cluster = data_service_test_base.TestCluster(num_workers=1) id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_1) id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_2) self.assertEqual(self.evaluate(id_1), self.evaluate(id_2)) @combinations.generate(test_base.default_test_combinations()) def testRegisterDifferentDatasets(self): ds_1 = dataset_ops.Dataset.range(10) ds_2 = dataset_ops.Dataset.range(20) cluster = data_service_test_base.TestCluster(num_workers=1) id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_1) id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(), ds_2) self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2)) @combinations.generate(test_base.default_test_combinations()) def testTwoLevelDistribute(self): cluster_1_size = 3 cluster_1 = data_service_test_base.TestCluster( num_workers=cluster_1_size) cluster_2 = data_service_test_base.TestCluster(num_workers=1) num_sizes = 10 size_repeats = 5 strings = ["a" * i for i in range(num_sizes)] * size_repeats ds = dataset_ops.Dataset.from_tensor_slices(strings) ds = ds.shuffle(len(strings)) ds = self.make_distributed_dataset(ds, cluster_1) # Large enough so that all strings of the same size are windowed together. 
window_size = cluster_1_size * size_repeats batch_size = size_repeats def key_func(x): return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64) ds = ds.apply( grouping.group_by_window( key_func=key_func, reduce_func=lambda _, x: x.batch(batch_size), window_size=window_size)) ds = self.make_distributed_dataset(ds, cluster_2) get_next = self.getNext(ds) for _ in range(num_sizes): element = self.evaluate(get_next()) for _ in range(1, cluster_1_size): self.assertAllEqual(self.evaluate(get_next()), element) self.assertEmpty(self.getIteratorOutput(get_next)) @combinations.generate( combinations.times(test_base.default_test_combinations())) def testDistributeLargeGraph(self): cluster = data_service_test_base.TestCluster(num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False) # Larger than default OSS grpc message size limit of 4MB. tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32) ds = dataset_ops.Dataset.from_tensors(tensor) ds = self.make_distributed_dataset(ds, cluster) self.assertDatasetProduces(ds, [tensor])
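# Sketch of the register/consume flow that testFromDatasetId exercises,
# written against the public tf.data service API rather than the internal
# test helpers. "grpc://localhost:5000" is a placeholder dispatcher address
# and `register_and_consume` is a hypothetical helper name.

import tensorflow as tf

def register_and_consume(service="grpc://localhost:5000"):
  ds = tf.data.Dataset.range(10)
  # Registering returns an id; any client that knows the id and the
  # element_spec can read the dataset back from the service.
  dataset_id = tf.data.experimental.service.register_dataset(service, ds)
  return tf.data.experimental.service.from_dataset_id(
      processing_mode="parallel_epochs",
      service=service,
      dataset_id=dataset_id,
      element_spec=ds.element_spec)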
class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testBasic(self): components = np.random.randint(12, size=(100, )).astype(np.int32) dataset = dataset_ops.Dataset.from_tensor_slices(components).map( lambda x: array_ops.fill([x], x)).apply( batching.dense_to_sparse_batch(4, [12])) get_next = self.getNext(dataset) for start in range(0, len(components), 4): results = self.evaluate(get_next()) self.assertAllEqual( [[i, j] for i, c in enumerate(components[start:start + 4]) for j in range(c)], results.indices) self.assertAllEqual( [c for c in components[start:start + 4] for _ in range(c)], results.values) self.assertAllEqual([min(4, len(components) - start), 12], results.dense_shape) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testWithUnknownShape(self): components = np.random.randint(5, size=(40, )).astype(np.int32) dataset = dataset_ops.Dataset.from_tensor_slices(components).map( lambda x: array_ops.fill([x, x], x)).apply( batching.dense_to_sparse_batch(4, [5, None])) get_next = self.getNext(dataset) for start in range(0, len(components), 4): results = self.evaluate(get_next()) self.assertAllEqual( [[i, j, z] for i, c in enumerate(components[start:start + 4]) for j in range(c) for z in range(c)], results.indices) self.assertAllEqual([ c for c in components[start:start + 4] for _ in range(c) for _ in range(c) ], results.values) self.assertAllEqual([ min(4, len(components) - start), 5, np.max(components[start:start + 4]) ], results.dense_shape) with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) @combinations.generate(test_base.default_test_combinations()) def testWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) with self.assertRaisesRegex(ValueError, "Dimension -2 must be >= 0"): dataset_ops.Dataset.from_tensors(input_tensor).apply( batching.dense_to_sparse_batch(4, [-2])) @combinations.generate(test_base.default_test_combinations()) def testShapeErrors(self): def dataset_fn(input_tensor): return dataset_ops.Dataset.from_tensors(input_tensor).apply( batching.dense_to_sparse_batch(4, [12])) # Initialize with an input tensor of incompatible rank. get_next = self.getNext(dataset_fn([[1]])) with self.assertRaisesRegex(errors.InvalidArgumentError, "incompatible with the row shape"): self.evaluate(get_next()) # Initialize with an input tensor that is larger than `row_shape`. get_next = self.getNext(dataset_fn(np.int32(range(13)))) with self.assertRaisesRegex(errors.DataLossError, "larger than the row shape"): self.evaluate(get_next())
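# Sketch of what dense_to_sparse_batch does, using the public API. Rows of
# different lengths are gathered into one SparseTensor per batch whose
# dense_shape is [batch_size] + row_shape. `dense_to_sparse_example` is a
# hypothetical helper name, not something the tests above define.

import tensorflow as tf

def dense_to_sparse_example():
  # Elements are vectors of length 0, 1, 2, 3.
  ds = tf.data.Dataset.range(4).map(lambda x: tf.fill([x], x))
  ds = ds.apply(
      tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[4]))
  for batch in ds:
    # Each `batch` is a tf.sparse.SparseTensor with dense_shape [2, 4].
    print(tf.sparse.to_dense(batch).numpy())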
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, parameterized.TestCase): def setUp(self): super(SnapshotDatasetTest, self).setUp() self.removeTFRecords() def removeTFRecords(self): for filename in self.test_filenames: os.remove(filename) self.test_filenames = [] def setUpTFRecord(self, num_files=10, num_records=10): self._num_files = num_files self._num_records = num_records self.test_filenames = self._createFiles() def makeSnapshotDirectory(self): tmpdir = self.get_temp_dir() tmpdir = os.path.join(tmpdir, "snapshot") os.mkdir(tmpdir) return tmpdir def assertSnapshotDirectoryContains(self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files): dirlist = os.listdir(directory) self.assertLen(dirlist, num_fingerprints) for i in range(num_fingerprints): fingerprint_dir = os.path.join(directory, dirlist[i]) fingerprint_dir_list = sorted(os.listdir(fingerprint_dir)) self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1) self.assertEqual(fingerprint_dir_list[num_runs_per_fp], "snapshot.metadata") for j in range(num_runs_per_fp): run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j]) run_dirlist = sorted(os.listdir(run_dir)) self.assertLen(run_dirlist, num_snapshot_files) file_counter = 0 for filename in run_dirlist: self.assertEqual(filename, "%08d.snapshot" % file_counter) file_counter += 1 @combinations.generate(test_base.default_test_combinations()) def testWriteDifferentPipelinesInOneDirectory(self): tmpdir = self.makeSnapshotDirectory() dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(1000))) dataset = dataset_ops.Dataset.range(1001) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(1001))) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotMultipleSimultaneous(self): tmpdir = self.makeSnapshotDirectory() dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(tmpdir)) next1 = self.getNext(dataset1) dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply(snapshot.snapshot(tmpdir)) next2 = self.getNext(dataset2) for _ in range(1000): self.evaluate(next1()) self.evaluate(next2()) # we check that only one copy of the metadata has been written, and the # one that lost the race would be in passthrough mode. self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testGetNextCreatesDir(self): tmpdir = self.makeSnapshotDirectory() # We create two iterators but call getNext on only one. dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply(snapshot.snapshot(tmpdir)) next1 = self.getNext(dataset1) dataset2 = dataset_ops.Dataset.range(1001) dataset2 = dataset2.apply(snapshot.snapshot(tmpdir)) _ = self.getNext(dataset2) for _ in range(1000): self.evaluate(next1()) # We check that only one directory is created. 
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP ]))) def testWriteSnapshotSimpleSuccessful(self, compression): tmpdir = self.makeSnapshotDirectory() dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply( snapshot.snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset, list(range(1000))) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testWriteSnapshotRepeatAfterwards(self): tmpdir = self.makeSnapshotDirectory() dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply(snapshot.snapshot(tmpdir)) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP ]))) def testReadSnapshotBackAfterWrite(self, compression): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.makeSnapshotDirectory() dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset2, expected) @combinations.generate(test_base.default_test_combinations()) def testReadShuffledSnapshotAfterWrite(self): self.setUpTFRecord(num_files=10, num_records=50) filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 50) ] tmpdir = self.makeSnapshotDirectory() dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot(tmpdir, shuffle_on_read=True)) next2 = self.getNext(dataset2) res1 = self.evaluate(next2()) res2 = self.evaluate(next2()) res3 = self.evaluate(next2()) res4 = self.evaluate(next2()) res5 = self.evaluate(next2()) # make sure that we don't read the file back in the same order. 
self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5]) # make sure all the elements are still there dataset3 = core_readers._TFRecordDataset(filenames) dataset3 = dataset3.apply( snapshot.snapshot(tmpdir, shuffle_on_read=True)) self.assertDatasetProduces(dataset3, expected, assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotParallelAfterWrite(self): self.setUpTFRecord(10, 4000) filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 4000) ] tmpdir = self.makeSnapshotDirectory() dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.snapshot(tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, reader_buffer_size=10)) self.assertDatasetProduces(dataset, expected, assert_items_equal=True) # remove the original files and try to read the data back only from # snapshot. self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot(tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, reader_buffer_size=10)) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.times( combinations.combine(compression=[ snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP ]), combinations.combine(threads=2, size=[1, 2]) + combinations.combine(threads=8, size=[1, 4, 8])))) def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads, size): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.makeSnapshotDirectory() dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply( snapshot.snapshot(tmpdir, compression=compression, num_writer_threads=threads, writer_buffer_size=size)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from # snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply( snapshot.snapshot(tmpdir, compression=compression)) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testSameFingerprintWithDifferentInitializationOrder(self): tmpdir = self.makeSnapshotDirectory() dataset1 = dataset_ops.Dataset.range(0, 100) dataset2 = dataset_ops.Dataset.range(100, 200) dataset3 = dataset_ops.Dataset.range(200, 300) dataset = dataset1.concatenate(dataset2).concatenate(dataset3) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(300))) dataset4 = dataset_ops.Dataset.range(200, 300) dataset5 = dataset_ops.Dataset.range(100, 200) dataset6 = dataset_ops.Dataset.range(0, 100) dataset = dataset6.concatenate(dataset5).concatenate(dataset4) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, list(range(300))) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) @combinations.generate(test_base.default_test_combinations()) def testExpiredSnapshotRewrite(self): tmpdir = self.makeSnapshotDirectory() dataset1 = dataset_ops.Dataset.range(1000) dataset1 = dataset1.apply( snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) next1 = self.getNext(dataset1) # Don't finish reading 
dataset1, so it is never finalized for _ in range(500): self.evaluate(next1()) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) time.sleep(2) # Creating dataset2 after we run through dataset1 due to eager mode, where # the snapshot state is determined immediately upon dataset creation. We # only want to determine the snapshot state for dataset2 after the first # snapshot has expired. dataset2 = dataset_ops.Dataset.range(1000) dataset2 = dataset2.apply( snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1)) next2 = self.getNext(dataset2) for _ in range(500): self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1) @combinations.generate(test_base.default_test_combinations()) def testSpecifyShardSize(self): tmpdir = self.makeSnapshotDirectory() dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset.map( lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.repeat(10) dataset = dataset.apply( snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024)) next_fn = self.getNext(dataset) for _ in range(10): self.evaluate(next_fn()) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 4) @combinations.generate(test_base.default_test_combinations()) def testAdditionalOperationsAfterReadBack(self): self.setUpTFRecord() filenames = self.test_filenames expected = [ b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] tmpdir = self.makeSnapshotDirectory() dataset = core_readers._TFRecordDataset(filenames) dataset = dataset.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset, expected) # remove the original files and try to read the data back only from snapshot self.removeTFRecords() dataset2 = core_readers._TFRecordDataset(filenames) dataset2 = dataset2.apply(snapshot.snapshot(tmpdir)) self.assertDatasetProduces(dataset2, expected) expected_after = [ b"cord %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension for f in range(0, 10) for r in range(0, 10) ] dataset3 = core_readers._TFRecordDataset(filenames) dataset3 = dataset3.apply(snapshot.snapshot(tmpdir)) dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000)) self.assertDatasetProduces(dataset3, expected_after)
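# Sketch of the write-then-read-back flow the snapshot tests above cover,
# using the public `tf.data.experimental.snapshot` transformation. The tests
# use an older apply()-style variant with extra knobs (shard_size_bytes,
# num_reader_threads, pending_snapshot_expiry_seconds) that the public API
# does not expose. "/tmp/snapshot_example" is a placeholder path.

import tensorflow as tf

def snapshot_example(path="/tmp/snapshot_example"):
  ds = tf.data.Dataset.range(100)
  # The first full iteration writes the snapshot; later iterations of a
  # pipeline with the same fingerprint read it back instead of recomputing
  # the upstream input.
  return ds.apply(tf.data.experimental.snapshot(path, compression="GZIP"))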
class AutoShardTest(data_service_test_base.TestBase, tf_record_test_base.TFRecordTestBase, parameterized.TestCase): """Tests auto-sharding datasets with tf.data service.""" def setUp(self): super(AutoShardTest, self).setUp() self._num_files = 10 self._num_records = 10 self._filenames = self._createFiles() @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ ShardingPolicy.DATA, ShardingPolicy.FILE_OR_DATA ]))) def testRangeDataset_AutoShard(self, sharding_policy): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy) self.assertDatasetProduces(dataset, [1, 6, 11, 16]) @combinations.generate(test_base.default_test_combinations()) def testRangeDataset_FileShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE) with self.assertRaisesRegex(errors.NotFoundError, "Found an unshardable source dataset"): self.getDatasetOutput(dataset) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(worker_index=[distribute.SHARD_HINT, 0, 5]))) def testRangeDataset_ShardHint(self, worker_index): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) # With HINT sharding, `num_shards` should be `SHARD_HINT`; `index` can be # any value. dataset = dataset.shard(num_shards=distribute.SHARD_HINT, index=worker_index) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT) self.assertDatasetProduces(dataset, [1, 6, 11, 16]) @combinations.generate(test_base.default_test_combinations()) def testRangeDataset_InvalidWorkerIndexUsingShardHint(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) # With HINT sharding, `SHARD_HINT` should be passed to `num_shards`, not # `index`. with self.assertRaisesRegex( errors.InvalidArgumentError, r"Index must be between 0 and 4 \(currently index = -1\)."): dataset = dataset.shard(num_shards=5, index=distribute.SHARD_HINT) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT) self.getDatasetOutput(dataset) @combinations.generate(test_base.default_test_combinations()) def testRangeDataset_NoShardHint(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) # No SHARD_HINT is provided. The given sharding arguments will be used. 
dataset = dataset.shard(num_shards=1, index=0) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT) self.assertDatasetProduces(dataset, list(range(20))) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ ShardingPolicy.OFF, ShardingPolicy.FILE_OR_DATA ]))) def testRangeDataset_ShardHintUsedInWrongShardingPolicy( self, sharding_policy): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy) with self.assertRaisesRegex( errors.FailedPreconditionError, "tf.data service with " "`tf.data.experimental.service.ShardingPolicy.HINT` processing mode." ): self.getDatasetOutput(dataset) @combinations.generate(test_base.default_test_combinations()) def testRangeDataset_NoShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF, target_workers="LOCAL") self.assertDatasetProduces(dataset, list(range(20))) @combinations.generate(test_base.default_test_combinations()) def testRangeDataset_OneWorker(self): """Makes sure shards from all workers form the complete dataset.""" cluster = _make_service_cluster(num_workers=1, local_shard_index=0) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) self.assertDatasetProduces(dataset, list(range(20))) @combinations.generate(test_base.default_test_combinations()) def testRangeDataset_ReadFromAllWorkers(self): """Makes sure shards from all workers form the complete dataset.""" cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA, target_workers="ANY") with self.assertRaisesRegex( errors.InvalidArgumentError, "Static sharding requires reading from local workers"): self.getDatasetOutput(dataset) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE ]))) def testTFRecordDataset_AutoShard(self, sharding_policy): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy, target_workers="LOCAL") expected = [ b"Record %d of file %d" % (record, file) for file in (3, 8) for record in range(0, 10) ] self.assertDatasetProduces(dataset, expected) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE ]))) def testTFRecordDataset_ShuffleFileList(self, sharding_policy): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=True) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy) 
expected = [ b"Record %d of file %d" % (record, file) for file in (3, 8) for record in range(0, 10) ] self.assertDatasetProduces(dataset, expected, assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_DataShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.DATA) expected = [ b"Record %d of file %d" % (record, file) for file in range(0, 10) for record in (3, 8) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_HintDataShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT) expected = [ b"Record %d of file %d" % (record, file) for file in range(0, 10) for record in (3, 8) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_HintFileShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT) expected = [ b"Record %d of file %d" % (record, file) for file in (3, 8) for record in range(0, 10) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_NoShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF, target_workers="LOCAL") expected = [ b"Record %d of file %d" % (record, file) for file in range(0, 10) for record in range(0, 10) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_ReadFromAllWorkers(self): """Makes sure shards from all workers form the complete dataset.""" cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA, target_workers="ANY") with self.assertRaisesRegex( errors.InvalidArgumentError, "Static sharding requires reading from local workers"): self.getDatasetOutput(dataset) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE ]))) def testTFRecordDataset_FewerFilesThanWorkers(self, sharding_policy): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = 
dataset_ops.Dataset.list_files(self._filenames[:4], shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy) with self.assertRaisesRegex( errors.InvalidArgumentError, "not enough for the required 5 shards/workers."): self.getDatasetOutput(dataset) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_FewerFilesThanWorkers_HintShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames[:4], shuffle=False) dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT) with self.assertRaisesRegex( errors.InvalidArgumentError, "not enough for the required 5 shards/workers."): self.getDatasetOutput(dataset) @combinations.generate(test_base.default_test_combinations()) def testTFRecordDataset_FewerFilesThanWorkers_DataShard(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames[:4], shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.DATA) expected = [ b"Record %d of file %d" % (record, file) for file in range(0, 4) for record in (3, 8) ] self.assertDatasetProduces(dataset, expected, assert_items_equal=True) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=[ ShardingPolicy.FILE_OR_DATA, ShardingPolicy.DATA ]))) def testBatchDataset(self, sharding_policy): cluster = _make_service_cluster(num_workers=5, local_shard_index=1) dataset = dataset_ops.Dataset.range(20) dataset = dataset.batch(batch_size=3, drop_remainder=False) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy) self.assertDatasetProduces(dataset, [[3, 4, 5], [18, 19]]) @combinations.generate(test_base.default_test_combinations()) def testInterleaveDataset(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.interleave(readers.TFRecordDataset, cycle_length=10, num_parallel_calls=dataset_ops.AUTOTUNE) dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) expected = [ b"Record %d of file %d" % (record, file) for record in range(0, 10) for file in (3, 8) ] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testZipDataset(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset1 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset1 = dataset1.interleave(readers.TFRecordDataset, cycle_length=10, num_parallel_calls=dataset_ops.AUTOTUNE) dataset2 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset2 = dataset2.interleave(readers.TFRecordDataset, cycle_length=10, num_parallel_calls=dataset_ops.AUTOTUNE) dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) dataset = 
self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) expected = [(b"Record %d of file %d" % (record, file), b"Record %d of file %d" % (record, file)) for record in range(0, 10) for file in (3, 8)] self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testConcatenateDataset(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset1 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset1 = dataset1.interleave(readers.TFRecordDataset, cycle_length=10, num_parallel_calls=dataset_ops.AUTOTUNE) dataset2 = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset2 = dataset2.interleave(readers.TFRecordDataset, cycle_length=10, num_parallel_calls=dataset_ops.AUTOTUNE) dataset = dataset1.concatenate(dataset2) dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) expected = [ b"Record %d of file %d" % (record, file) for record in range(0, 10) for file in (3, 8) ] expected += expected self.assertDatasetProduces(dataset, expected) @combinations.generate(test_base.default_test_combinations()) def testEmptyDataset(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.range(0) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) self.assertDatasetProduces(dataset, []) @combinations.generate(test_base.default_test_combinations()) def testAnonymousPorts(self): cluster = _make_service_cluster( num_workers=5, local_shard_index=3, worker_addresses=["localhost:%port%" for _ in range(5)]) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) self.assertDatasetProduces(dataset, [3, 8, 13, 18]) @combinations.generate(test_base.default_test_combinations()) def testNamedPorts(self): cluster = _make_service_cluster( num_workers=5, local_shard_index=3, worker_addresses=["localhost:%port_worker%" for _ in range(5)]) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) self.assertDatasetProduces(dataset, [3, 8, 13, 18]) @combinations.generate(test_base.default_test_combinations()) def testInvalidPorts(self): with self.assertRaisesRegex(RuntimeError, "The worker's address is not configured"): _ = _make_service_cluster( num_workers=5, local_shard_index=0, worker_addresses=["localhost:worker" for _ in range(5)]) @combinations.generate(test_base.default_test_combinations()) def testEmptyWorkerList(self): cluster = _make_service_cluster(num_workers=5, local_shard_index=1, worker_addresses=[]) dataset = dataset_ops.Dataset.range(20) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) with self.assertRaisesRegex(errors.NotFoundError, "Worker .* is not in the workers list."): self.getDatasetOutput(dataset) @combinations.generate(test_base.default_test_combinations()) def testWorkerNotFound(self): worker_addresses = [f"fake_worker_{i}" for i in range(5)] with self.assertRaisesRegex(RuntimeError, "The worker's address is not configured"): _ = _make_service_cluster(num_workers=5, local_shard_index=0, worker_addresses=worker_addresses) 
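# Sketch of HINT sharding as exercised by testRangeDataset_ShardHint, written
# with the public API. This assumes the exported `tf.data.experimental.
# SHARD_HINT` constant and `ShardingPolicy` enum; the dispatcher address is a
# placeholder and a running cluster is required for elements to be produced.

import tensorflow as tf

def hint_sharded_dataset(service="grpc://localhost:5000"):
  ds = tf.data.Dataset.range(20)
  # With HINT sharding, num_shards must be SHARD_HINT; the index value is
  # substituted per worker by the tf.data service at runtime.
  ds = ds.shard(num_shards=tf.data.experimental.SHARD_HINT,
                index=tf.data.experimental.SHARD_HINT)
  return ds.apply(
      tf.data.experimental.service.distribute(
          processing_mode=tf.data.experimental.service.ShardingPolicy.HINT,
          service=service))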
@combinations.generate(test_base.default_test_combinations()) def testMoreWorkersThanConfigured(self): worker_addresses = ["localhost:%port%"] with self.assertRaisesRegex( RuntimeError, "other workers are already running at the configured host"): _ = _make_service_cluster(num_workers=5, local_shard_index=1, worker_addresses=worker_addresses) @combinations.generate(test_base.default_test_combinations()) def testNoLocalWorkers(self): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=0, num_remote_workers=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE_OR_DATA) with self.assertRaisesRegex( errors.InvalidArgumentError, "Local reads or static sharding require local tf.data workers" ): self.getDatasetOutput(dataset) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(sharding_policy=list(ShardingPolicy)))) def testEnumerateShardingPolicies(self, sharding_policy): """Verifies tf.data service handles every sharding policy with no errors.""" cluster = _make_service_cluster(num_workers=5, local_shard_index=3) dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False) dataset = dataset.flat_map(readers.TFRecordDataset) dataset = self.make_distributed_dataset( dataset, cluster=cluster, processing_mode=sharding_policy) self.getDatasetOutput(dataset)
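# Sketch relating the auto-shard expectations above to plain Dataset.shard:
# with 5 workers, the worker at shard index 1 sees every 5th element starting
# at 1, which is exactly what testRangeDataset_AutoShard asserts
# ([1, 6, 11, 16]). `static_shard_equivalent` is a hypothetical helper name.

import tensorflow as tf

def static_shard_equivalent(num_workers=5, worker_index=1):
  ds = tf.data.Dataset.range(20)
  return ds.shard(num_shards=num_workers, index=worker_index)

# list(static_shard_equivalent().as_numpy_iterator()) == [1, 6, 11, 16]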
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.graph_only_combinations()) def testNoGradients(self): component = constant_op.constant([1.]) side = constant_op.constant(0.) add = lambda x: x + side dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) value = dataset_ops.make_one_shot_iterator(dataset).get_next() self.assertIsNone(gradients_impl.gradients(value, component)[0]) self.assertIsNone(gradients_impl.gradients(value, side)[0]) self.assertIsNone( gradients_impl.gradients(value, [component, side])[0]) @combinations.generate(test_base.graph_only_combinations()) def testCapturingStateInOneShotRaisesException(self): var = variables.Variable(37.0, name="myvar") dataset = (dataset_ops.Dataset.from_tensor_slices( [0.0, 1.0, 2.0]).map(lambda x: x + var)) with self.assertRaisesRegex( ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support " "datasets that capture stateful objects.+myvar"): dataset_ops.make_one_shot_iterator(dataset) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIterator(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(components).map( _map_fn).repeat(14)) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorCaptureByValue(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) tensor_components = tuple( [ops.convert_to_tensor(c) for c in components]) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(tensor_components).map( _map_fn).repeat(14)) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.default_test_combinations()) def testOneShotIteratorInsideContainer(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def within_container(): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square( z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(components).map( _map_fn).repeat(14)) return iterator.get_next() server = server_lib.Server.create_local_server() # Create two iterators within unique containers, and run them to # make sure that the resources aren't shared. # # The test below would fail if cname were the same across both # sessions. 
for j in range(2): with session.Session(server.target) as sess: cname = "iteration%d" % j with ops.container(cname): get_next = within_container() for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip( components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorNonBlocking(self): dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() # Create a session with a single thread to ensure that the # one-shot iterator initializer does not deadlock. config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, use_per_session_threads=True) with session.Session(config=config) as sess: self.assertAllEqual([1, 4, 9], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) # Test with multiple threads invoking the one-shot iterator concurrently. with session.Session(config=config) as sess: results = [] def consumer_thread(): try: results.append(sess.run(next_element)) except errors.OutOfRangeError: results.append(None) num_threads = 8 threads = [ self.checkedThread(consumer_thread) for _ in range(num_threads) ] for t in threads: t.start() for t in threads: t.join() self.assertLen(results, num_threads) self.assertLen([None for r in results if r is None], num_threads - 1) self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) @combinations.generate(test_base.graph_only_combinations()) def testOneShotIteratorInitializerFails(self): # Define a dataset whose initialization will always fail. dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4])) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() with self.cached_session() as sess: with self.assertRaisesRegex(errors.InvalidArgumentError, ""): sess.run(next_element) # Test that subsequent attempts to use the iterator also fail. with self.assertRaisesRegex(errors.InvalidArgumentError, ""): sess.run(next_element) with self.cached_session() as sess: def consumer_thread(): with self.assertRaisesRegex(errors.InvalidArgumentError, ""): sess.run(next_element) num_threads = 8 threads = [ self.checkedThread(consumer_thread) for _ in range(num_threads) ] for t in threads: t.start() for t in threads: t.join() @combinations.generate(test_base.graph_only_combinations()) def testSimpleSharedResource(self): components = (np.array(1, dtype=np.int64), np.array([1, 2, 3], dtype=np.int64), np.array(37.0, dtype=np.float64)) server = server_lib.Server.create_local_server() # Create two non-overlapping sessions that share the same iterator # resource on the same server, and verify that an action of the # first session (initializing the iterator) is visible in the # second session. with ops.Graph().as_default(): iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensors(components).map( lambda x, y, z: (x, y, z)), shared_name="shared_iterator") init_op = iterator.initializer get_next = iterator.get_next() with session.Session(server.target) as sess: sess.run(init_op) results = sess.run(get_next) for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Re-initialize the iterator in the first session. 
sess.run(init_op) with ops.Graph().as_default(): # Re-define the iterator manually, without defining any of the # functions in this graph, to ensure that we are not # accidentally redefining functions with the same names in the # new graph. iterator = iterator_ops.Iterator.from_structure( shared_name="shared_iterator", output_types=(dtypes.int64, dtypes.int64, dtypes.float64), output_shapes=([], [3], [])) get_next = iterator.get_next() with session.Session(server.target) as sess: # Use the iterator without re-initializing in the second session. results = sess.run(get_next) for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testNotInitializedError(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensors(components)) get_next = iterator.get_next() with self.cached_session() as sess: with self.assertRaisesRegex(errors.FailedPreconditionError, "iterator has not been initialized"): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testReinitializableIterator(self): dataset_3 = dataset_ops.Dataset.from_tensors( constant_op.constant([1, 2, 3])) dataset_4 = dataset_ops.Dataset.from_tensors( constant_op.constant([4, 5, 6, 7])) iterator = iterator_ops.Iterator.from_structure( dataset_ops.get_legacy_output_types(dataset_3), [None]) dataset_3_init_op = iterator.make_initializer(dataset_3) dataset_4_init_op = iterator.make_initializer(dataset_4) get_next = iterator.get_next() self.assertEqual(dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(dataset_ops.get_legacy_output_types(dataset_4), dataset_ops.get_legacy_output_types(iterator)) self.assertEqual( [None], dataset_ops.get_legacy_output_shapes(iterator).as_list()) with self.cached_session() as sess: # The iterator is initially uninitialized. with self.assertRaises(errors.FailedPreconditionError): sess.run(get_next) # Initialize with one dataset. sess.run(dataset_3_init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Initialize with a different dataset. sess.run(dataset_4_init_op) self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Reinitialize with the first dataset. 
sess.run(dataset_3_init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @combinations.generate(test_base.graph_only_combinations()) def testReinitializableIteratorWithFunctions(self): def g(): for i in range(10): yield i iterator = iterator_ops.Iterator.from_structure(dtypes.int64, []) next_element = iterator.get_next() with self.cached_session() as sess: dataset_1 = dataset_ops.Dataset.from_generator( g, output_types=dtypes.int64) sess.run(iterator.make_initializer(dataset_1)) for expected in range(10): self.assertEqual(expected, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) dataset_2 = dataset_ops.Dataset.from_generator( g, output_types=dtypes.int64) sess.run(iterator.make_initializer(dataset_2)) for expected in range(10): self.assertEqual(expected, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @combinations.generate(test_base.default_test_combinations()) def testReinitializableIteratorStaticErrors(self): # Non-matching structure for types and shapes. with self.assertRaises(TypeError): iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), [None]) # Test validation of dataset argument. iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64)) # Incompatible structure. with self.assertRaises(ValueError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( ((constant_op.constant([1, 2, 3], dtype=dtypes.int64), ), (constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64), )))) # Incompatible types. with self.assertRaises(TypeError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( (constant_op.constant([1, 2, 3], dtype=dtypes.int32), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32)))) # Incompatible shapes. 
iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), ([None], [])) with self.assertRaises(TypeError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( (constant_op.constant([1, 2, 3], dtype=dtypes.int64), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64)))) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandle(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) next_element = feedable_iterator.get_next() self.assertTrue( structure.are_compatible( dataset_ops.get_structure(dataset_3), dataset_ops.get_structure(feedable_iterator))) with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) self.assertEqual( 10, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 1, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 20, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 2, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 30, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 3, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 40, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle}) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleFuture(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) next_element = feedable_iterator.get_next() self.assertTrue( structure.are_compatible( dataset_ops.get_structure(dataset_3), dataset_ops.get_structure(feedable_iterator))) with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) self.assertEqual( 10, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 1, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 20, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 2, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 30, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) 
self.assertEqual( 3, sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 40, sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle}) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleReuseTensorObject(self): dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset) initializable_iterator = dataset_ops.make_initializable_iterator( dataset) structure_iterator = iterator_ops.Iterator.from_structure( dataset_ops.get_legacy_output_types(dataset)) created_ops = len(ops.get_default_graph().get_operations()) self.assertIs(one_shot_iterator.string_handle(), one_shot_iterator.string_handle()) self.assertIs(initializable_iterator.string_handle(), initializable_iterator.string_handle()) self.assertIs(structure_iterator.string_handle(), structure_iterator.string_handle()) # Assert that getting the (default) string handle creates no ops. self.assertEqual(created_ops, len(ops.get_default_graph().get_operations())) # Specifying an explicit name will create a new op. handle_with_name = one_shot_iterator.string_handle(name="foo") self.assertEqual("foo", handle_with_name.op.name) self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name) handle_with_same_name = one_shot_iterator.string_handle(name="foo") self.assertEqual("foo_1", handle_with_same_name.op.name) self.assertIsNot(handle_with_name, handle_with_same_name) @combinations.generate(test_base.graph_only_combinations()) def testIteratorStringHandleError(self): dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices( [1, 2, 3]).repeat()) dataset_float_vector = (dataset_ops.Dataset.from_tensors( [1.0, 2.0, 3.0])) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_int_scalar = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, []) feedable_int_vector = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, [None]) feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) with self.cached_session() as sess: handle_int_scalar = sess.run( dataset_ops.make_one_shot_iterator( dataset_int_scalar).string_handle()) handle_float_vector = sess.run( dataset_ops.make_one_shot_iterator( dataset_float_vector).string_handle()) self.assertEqual( 1, sess.run(feedable_int_scalar.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) self.assertEqual( 2, sess.run(feedable_int_any.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) with self.assertRaises(errors.InvalidArgumentError): print( sess.run(feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) with self.assertRaises(errors.InvalidArgumentError): print( sess.run( feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_float_vector})) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSession(self): worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 3 with ops.device("/job:localhost/replica:0/task:0/cpu:1"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = 
iterator_3.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/cpu:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) remote_op = functional_ops.remote_call(args=[iterator_3_handle], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.session(config=worker_config) as sess: elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [1]) # Fails when target is cpu:2 where the resource is not located. with self.assertRaises(errors.InvalidArgumentError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" }) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [2]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self): s1 = server_lib.Server.create_local_server() s2 = server_lib.Server.create_local_server() s3 = server_lib.Server.create_local_server() cluster_def = cluster_pb2.ClusterDef() workers = cluster_def.job.add() workers.name = "worker" workers.tasks[0] = s1.target[len("grpc://"):] workers.tasks[1] = s2.target[len("grpc://"):] client = cluster_def.job.add() client.name = "client" client.tasks[0] = s3.target[len("grpc://"):] config = config_pb2.ConfigProto(cluster_def=cluster_def) worker_devices = [ "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2) ] itr_handles = [] for device in worker_devices: with ops.device(device): src = dataset_ops.Dataset.from_tensor_slices([device]) itr = dataset_ops.make_one_shot_iterator(src) itr_handles.append(itr.string_handle()) targets = dataset_ops.Dataset.from_tensor_slices(worker_devices) handles = dataset_ops.Dataset.from_tensor_slices(itr_handles) @function.Defun(dtypes.string) def loading_func(h): remote_itr = iterator_ops.Iterator.from_string_handle( h, dataset_ops.get_legacy_output_types(itr), dataset_ops.get_legacy_output_shapes(itr)) return remote_itr.get_next() def map_fn(target, handle): return functional_ops.remote_call(args=[handle], Tout=[dtypes.string], f=loading_func, target=target) with ops.device("/job:client"): client_dataset = dataset_ops.Dataset.zip( (targets, handles)).map(map_fn) itr = dataset_ops.make_initializable_iterator(client_dataset) n = itr.get_next() with session.Session(s3.target, config=config) as sess: sess.run(itr.initializer) expected_values = worker_devices for expected in expected_values: self.assertEqual((compat.as_bytes(expected), ), sess.run(n)) with self.assertRaises(errors.OutOfRangeError): sess.run(n) @combinations.generate(test_base.graph_only_combinations()) def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") with ops.device("/job:localhost/replica:0/task:0/cpu:0"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = 
iterator_3.string_handle() def _encode_raw(byte_array): return bytes(bytearray(byte_array)) @function.Defun(dtypes.uint8) def _remote_fn(h): handle = script_ops.py_func(_encode_raw, [h], dtypes.string) remote_iterator = iterator_ops.Iterator.from_string_handle( handle, dataset_ops.get_legacy_output_types(dataset_3), dataset_ops.get_legacy_output_shapes(dataset_3)) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) iterator_3_handle_uint8 = parsing_ops.decode_raw( input_bytes=iterator_3_handle, out_type=dtypes.uint8) remote_op = functional_ops.remote_call( args=[iterator_3_handle_uint8], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.cached_session() as sess: elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [1]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [2]) elem = sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run(remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) @combinations.generate(test_base.graph_only_combinations()) def testRepeatedGetNextWarning(self): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.range(10)) warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: for _ in range(100): iterator.get_next() self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w)) for warning in w: self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( expected_element_structure=tensor_spec.TensorSpec( [], dtypes.float32), expected_output_classes=ops.Tensor, expected_output_types=dtypes.float32, expected_output_shapes=[[]]))) def testTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): tf_value_fn = lambda: constant_op.constant(37.0) tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine( expected_element_structure=sparse_tensor.SparseTensorSpec( [1], dtypes.int32), expected_output_classes=sparse_tensor.SparseTensor, expected_output_types=dtypes.int32, expected_output_shapes=[[1]]))) def testSparseTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): def tf_value_fn(): return sparse_tensor.SparseTensor(indices=[[0]], values=constant_op.constant( [0], dtype=dtypes.int32), dense_shape=[1]) tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( 
structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(expected_element_structure={ "a": tensor_spec.TensorSpec([], dtypes.float32), "b": (tensor_spec.TensorSpec([1], dtypes.string), tensor_spec.TensorSpec([], dtypes.string)) }, expected_output_classes={ "a": ops.Tensor, "b": (ops.Tensor, ops.Tensor) }, expected_output_types={ "a": dtypes.float32, "b": (dtypes.string, dtypes.string) }, expected_output_shapes={ "a": [], "b": ([1], []) }))) def testNestedTensorIteratorStructure(self, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): def tf_value_fn(): return { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) } tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue( structure.are_compatible(dataset_ops.get_structure(iterator), expected_element_structure)) self.assertEqual(expected_output_classes, dataset_ops.get_legacy_output_classes(iterator)) self.assertEqual(expected_output_types, dataset_ops.get_legacy_output_types(iterator)) self.assertEqual(expected_output_shapes, dataset_ops.get_legacy_output_shapes(iterator)) @combinations.generate(test_base.default_test_combinations()) def testIteratorGetNextName(self): with ops.Graph().as_default(): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(37.0)) next_element = iterator.get_next(name="overridden_name") self.assertEqual("overridden_name", next_element.op.name) @combinations.generate( combinations.combine(tf_api_version=[1, 2], mode="eager", execution_mode=[context.ASYNC, context.SYNC])) def testIteratorEagerIteration(self, execution_mode): with context.eager_mode(), context.execution_mode(execution_mode): val = 0 dataset = dataset_ops.Dataset.range(10) iterator = iter(dataset) for foo in iterator: self.assertEqual(val, foo.numpy()) val += 1 @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunction(self): queue = data_flow_ops.FIFOQueue(10, dtypes.int64) @def_function.function def fn(): dataset = dataset_ops.Dataset.range(10) iterator = iter(dataset) for _ in range(10): queue.enqueue(next(iterator)) fn() for i in range(10): self.assertEqual(queue.dequeue().numpy(), i) @combinations.generate(test_base.eager_only_combinations()) def testOwnedIteratorFunctionError(self): # In this test we verify that a function that raises an error ends up # properly deallocating the iterator resource. 
queue = data_flow_ops.FIFOQueue(10, dtypes.int64) queue.enqueue(0) def init_fn(n): return n def next_fn(_): ds = dataset_ops.Dataset.range(0) return next(iter(ds)) def finalize_fn(n): queue.enqueue(0) return n @def_function.function def fn(): output_signature = tensor_spec.TensorSpec((), dtypes.int64) dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, output_signature) iterator = iter(dataset) next(iterator) with self.assertRaises(errors.OutOfRangeError): fn() self.assertEqual(queue.size().numpy(), 2) @combinations.generate(test_base.eager_only_combinations()) def testLimitedRetracing(self): trace_count = [0] @def_function.function def f(iterator): trace_count[0] += 1 counter = np.int64(0) for elem in iterator: counter += elem return counter dataset = dataset_ops.Dataset.range(5) dataset2 = dataset_ops.Dataset.range(10) for _ in range(10): self.assertEqual(self.evaluate(f(iter(dataset))), 10) self.assertEqual(self.evaluate(f(iter(dataset2))), 45) self.assertEqual(trace_count[0], 1) @combinations.generate(test_base.eager_only_combinations()) def testNestedFunctionsIteratorResource(self): @def_function.function def sum_dataset(ds): it = iter(ds) @def_function.function def next_element(it): return next(it) total = 0 for _ in range(10): total += next_element(it) return total ds = dataset_ops.Dataset.range(10) self.assertEqual(sum_dataset(ds).numpy(), 45) self.assertEqual(sum_dataset(ds).numpy(), 45) @combinations.generate(test_base.default_test_combinations()) def testNestedAutomaticControlDependencies(self): counter_var = variables.Variable(0) def map_fn(x): counter_var.assign_add(1) return x def dataset_fn(): return dataset_ops.Dataset.range(10).map(map_fn) @def_function.function def fn(): it = iter(dataset_fn()) for _ in range(10): _ = next(it) return counter_var self.evaluate(counter_var.initializer) self.assertEqual(self.evaluate(fn()), 10)
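# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original test suite): the tests above
# consume tf.data iterators from inside `tf.function`s and rely on a single
# trace being reused across compatible datasets. A minimal, hedged example of
# that pattern follows; the helper name `_example_sum_with_owned_iterator` is
# ours, and it assumes the `dataset_ops`, `def_function`, and `np` modules
# imported elsewhere in this file.
def _example_sum_with_owned_iterator():
  """Sums a small range dataset by consuming an owned iterator in a tf.function."""

  @def_function.function
  def consume(iterator):
    total = np.int64(0)
    for element in iterator:  # Traced once; reused for compatible iterators.
      total += element
    return total

  dataset = dataset_ops.Dataset.range(5)
  return consume(iter(dataset))  # Expected value: 0 + 1 + 2 + 3 + 4 = 10.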
class FlatMapDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testCore(self):

    # Complicated way of saying range(start, start+25).
    def build_ds(start):

      def map_fn(x):
        return dataset_ops.Dataset.range(x, x + 5)

      return dataset_ops.Dataset.range(start, start + 5 * 5,
                                       5).flat_map(map_fn)

    self.run_core_tests(lambda: build_ds(0), 25)

  @combinations.generate(test_base.default_test_combinations())
  def testMapThenFlatMap(self):

    def build_ds():

      def flat_map_fn(_):

        def map_fn(y):
          return 10 * math_ops.cast(y, dtypes.int32)

        return dataset_ops.Dataset.range(100).map(map_fn)

      return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)

    self.run_core_tests(build_ds, 500)

  @combinations.generate(test_base.default_test_combinations())
  def testCaptureDefunInMapFn(self):

    def build_ds():

      def map_fn(x):

        @function.Defun(dtypes.int64)
        def defun_fn(x):
          return constant_op.constant(1000) + math_ops.cast(x, dtypes.int32)

        return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)])

      return dataset_ops.Dataset.range(100).flat_map(map_fn)

    self.run_core_tests(build_ds, 100)

  @combinations.generate(test_base.default_test_combinations())
  def testDisallowVariableCapture(self):

    def build_ds():
      test_var = variable_scope.get_variable(
          name="test_var", shape=(), use_resource=True)
      return dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))

    self.verify_error_on_save(build_ds, 5, errors.FailedPreconditionError)

  @combinations.generate(test_base.default_test_combinations())
  def testDisallowCapturingStatefulOps(self):

    def build_ds():

      def flat_map_fn(_):

        def map_fn(x):
          return random_ops.random_uniform(
              (), 0, 10, dtype=dtypes.int32) * math_ops.cast(x, dtypes.int32)

        return dataset_ops.Dataset.range(100).map(map_fn)

      return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)

    self.verify_error_on_save(build_ds, 500, errors.FailedPreconditionError)

  @combinations.generate(test_base.default_test_combinations())
  def testSparseCore(self):

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=(i * [1, -1]),
          dense_shape=[2, 2])

    def _flat_map_fn(x):
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    def _build_ds():
      return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)

    self.run_core_tests(_build_ds, 20)
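# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original test suite): the checkpoint
# tests above build their pipelines with `flat_map`. The sketch below shows
# the same construction `testCore` uses, as a standalone helper; the name
# `_example_flat_map_range` is ours.
def _example_flat_map_range(start):
  """Builds range(start, start + 25) out of five 5-element sub-datasets."""

  def map_fn(x):
    # Each input element expands into its own 5-element dataset.
    return dataset_ops.Dataset.range(x, x + 5)

  # range(start, start + 25, 5) -> flat_map -> 25 consecutive integers.
  return dataset_ops.Dataset.range(start, start + 25, 5).flat_map(map_fn)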
class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase): """Tests reading from local workers if `target_workers` is `local`.""" @combinations.generate(test_base.default_test_combinations()) def testOneLocalWorker(self): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=1, num_remote_workers=5) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="local") self.assertDatasetProduces(ds, list(range(num_elements))) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testLocalWorkers(self, num_local_workers, num_remote_workers): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="LOCAL") self.assertDatasetProduces(ds, num_local_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testRepeatedDataset(self, num_local_workers, num_remote_workers): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 10 num_repetitions = 5 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="LOCAL") ds = ds.repeat(num_repetitions) self.assertDatasetProduces(ds, expected_output=num_local_workers * num_repetitions * list(range(num_elements)), assert_items_equal=True) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testPrefetchingDataset(self, num_local_workers, num_remote_workers): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="LOCAL") ds = ds.prefetch(10) self.assertDatasetProduces(ds, expected_output=num_local_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testMultipleEpochs(self, num_local_workers, num_remote_workers): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="LOCAL") for _ in range(10): self.assertDatasetProduces(ds, num_local_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testDynamicSharding(self, num_local_workers, num_remote_workers): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 100 ds = self.make_distributed_range_dataset( num_elements, cluster, processing_mode=ShardingPolicy.DYNAMIC, target_workers="LOCAL") self.assertDatasetProduces(ds, list(range(num_elements)), assert_items_equal=True) @combinations.generate( combinations.times( 
test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testEmptyDataset(self, num_local_workers, num_remote_workers): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 0 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="LOCAL") self.assertDatasetProduces(ds, []) @combinations.generate( combinations.times( test_base.default_test_combinations(), combinations.combine(num_local_workers=[1, 3], num_remote_workers=[0, 3]))) def testNonLocalRead(self, num_local_workers, num_remote_workers): """This test ensures the remote workers are running and producing data.""" cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=num_local_workers, num_remote_workers=num_remote_workers) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="any") num_workers = num_local_workers + num_remote_workers self.assertDatasetProduces(ds, num_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testNoLocalWorker(self): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=0, num_remote_workers=3) num_elements = 10 ds = self.make_distributed_range_dataset(num_elements, cluster, target_workers="LOCAL") with self.assertRaisesRegex( errors.InvalidArgumentError, "Local reads require local tf.data workers, but no local worker is " "found."): self.getDatasetOutput(ds) @combinations.generate(test_base.default_test_combinations()) def testInconsistentTargetWorkers(self): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=3, num_remote_workers=3) ds = dataset_ops.Dataset.range(10) datasets = [ self.make_distributed_dataset(ds, cluster, job_name="test_job", target_workers=target_workers) for target_workers in ["AUTO", "ANY", "LOCAL"] ] with self.assertRaisesRegex( errors.InvalidArgumentError, "but there is already an existing job with that name using " "target_workers <AUTO>."): for dataset in datasets: self.getDatasetOutput(dataset) @combinations.generate(test_base.default_test_combinations()) def testAnonymousJobWithDifferentTargetWorkers(self): num_local_workers, num_remote_workers = (3, 3) cluster = multi_process_cluster.MultiProcessCluster( num_local_workers, num_remote_workers) num_elements = 10 ds = dataset_ops.Dataset.range(num_elements) datasets = { target_workers: self.make_distributed_dataset(ds, cluster, target_workers=target_workers) for target_workers in ["AUTO", "ANY", "LOCAL"] } num_workers = num_local_workers + num_remote_workers self.assertDatasetProduces(datasets["AUTO"], num_workers * list(range(num_elements)), assert_items_equal=True) self.assertDatasetProduces(datasets["ANY"], num_workers * list(range(num_elements)), assert_items_equal=True) self.assertDatasetProduces(datasets["LOCAL"], num_local_workers * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.default_test_combinations()) def testCoordinatedRead(self): cluster = multi_process_cluster.MultiProcessCluster( num_local_workers=3, num_remote_workers=3) ds = dataset_ops.Dataset.range(10).repeat() ds = self.make_distributed_dataset(ds, cluster, job_name="test_job", consumer_index=0, num_consumers=3, target_workers="LOCAL") with self.assertRaisesRegex( errors.InvalidArgumentError, "Coordinated reads require non-local workers"): self.getDatasetOutput(ds)
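# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original test suite): the tests above
# read from a tf.data service cluster with `target_workers="LOCAL"` so that
# only workers co-located with the client serve data. A hedged sketch of the
# equivalent public-API usage follows; the dispatcher address and the helper
# name `_example_read_from_local_workers` are ours, and the `target_workers`
# argument assumes a TensorFlow release in which
# `tf.data.experimental.service.distribute` accepts it.
def _example_read_from_local_workers(dispatcher_address):
  import tensorflow as tf  # Assumed available; the tests use internal modules.

  dataset = tf.data.Dataset.range(10)
  return dataset.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs",  # Every worker produces each element.
          service="grpc://" + dispatcher_address,
          target_workers="LOCAL"))  # Read only from workers local to this client.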
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.default_test_combinations()) def testOptimizationStatefulFunction(self): dataset = dataset_ops.Dataset.range(10).map( lambda _: random_ops.random_uniform([])).batch(10) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) get_next = self.getNext(dataset) self.evaluate(get_next()) # TODO(b/123902160) @combinations.generate(test_base.graph_only_combinations()) def testOptimizationLargeInputFromTensor(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) dataset = dataset_ops.Dataset.from_tensors(input_t) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) self.evaluate(get_next) # TODO(b/123902160) @combinations.generate(test_base.graph_only_combinations()) def testOptimizationLargeInputFromTensorSlices(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) dataset = dataset_ops.Dataset.from_tensor_slices(input_t) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) self.evaluate(get_next) @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDataset(self): def flat_map_fn(_): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"])) dataset = dataset.skip(0) # Should be removed by noop elimination dataset = dataset.cache() return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.flat_map(flat_map_fn) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.noop_elimination = True dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[0]) @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDatasetWithModifiedRetval(self): def flat_map_fn(_): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["MapAndBatch"])) # Should be fused by map and batch fusion dataset = dataset.map(lambda x: x) dataset = dataset.batch(1) return dataset dataset = dataset_ops.Dataset.range(1) dataset = dataset.flat_map(flat_map_fn) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.map_and_batch_fusion = True dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[[0]]) @combinations.generate( combinations.times(test_base.default_test_combinations(), _disable_intra_op_parallelism_test_combinations())) def testOptimizationDisableIntraOpParallelism(self, dataset_fn, expected_output): os.environ[ "TF_DATA_EXPERIMENT_OPT_IN"] = "disable_intra_op_parallelism" os.environ["TF_JOB_NAME"] = "test_job" dataset = dataset_fn() dataset = 
dataset.apply(testing.assert_next(["MaxIntraOpParallelism"])) self.assertDatasetProduces(dataset, expected_output=expected_output) del os.environ["TF_DATA_EXPERIMENT_OPT_IN"] del os.environ["TF_JOB_NAME"] @combinations.generate(test_base.default_test_combinations()) def testOptimizationThreadPoolDataset(self): dataset = dataset_ops.Dataset.range(10).batch(10) dataset = threadpool.override_threadpool( dataset, threadpool.PrivateThreadPool( 2, display_name="private_thread_pool_%d" % 2)) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[list(range(10))], requires_initialization=True) # Reference variables are not supported in eager mode. @combinations.generate( combinations.times(test_base.graph_only_combinations(), _captured_refvar_test_combinations())) def testOptimizationWithCapturedRefVar(self, dataset_fn): """Tests that default optimizations are disabled with ref variables.""" variable = variable_scope.get_variable("v", initializer=0, use_resource=False) assign_op = variable.assign_add(1) # Check that warning is logged. warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: unoptimized_dataset = dataset_fn(variable) options = dataset_ops.Options() options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.noop_elimination = True options.experimental_optimization.map_and_batch_fusion = True optimized_dataset = unoptimized_dataset.with_options(options) optimized_it = dataset_ops.make_initializable_iterator( optimized_dataset) self.assertGreaterEqual(len(w), 1) graph_rewrites = options._graph_rewrites() expected = ( "tf.data graph rewrites are not compatible with " "tf.Variable. The following rewrites will be disabled: %s." " To enable rewrites, use resource variables instead by " "calling `tf.enable_resource_variables()` at the start of the " "program." % (", ".join(graph_rewrites.enabled + graph_rewrites.default))) self.assertTrue(any(expected in str(warning) for warning in w)) # Check that outputs are the same in the optimized and unoptimized cases, # when the variable value is changing. 
unoptimized_it = dataset_ops.make_initializable_iterator( unoptimized_dataset) with ops.control_dependencies([assign_op]): unoptimized_output = unoptimized_it.get_next() optimized_output = optimized_it.get_next() self.evaluate(variable.initializer) self.evaluate((unoptimized_it.initializer, optimized_it.initializer)) while True: try: unoptimized, optimized = self.evaluate( (unoptimized_output, optimized_output)) self.assertEqual(unoptimized, optimized) except errors.OutOfRangeError: break @combinations.generate(test_base.default_test_combinations()) def testOptimizationDefault(self): """Tests the optimization settings by default.""" options = dataset_ops.Options() expected_optimizations_enabled = [] expected_optimizations_disabled = [] expected_optimizations_default = [ "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion", ] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) options.experimental_optimization.apply_default_optimizations = True graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) options.experimental_optimization.apply_default_optimizations = False expected_optimizations_default = [] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) self.assertEqual(set(graph_rewrites.disabled), set(expected_optimizations_disabled)) self.assertEqual(set(graph_rewrites.default), set(expected_optimizations_default)) @combinations.generate(test_base.default_test_combinations()) def testOptimizationEnabled(self): """Tests the optimization settings by enabling all.""" options = dataset_ops.Options() options.experimental_optimization.filter_fusion = True options.experimental_optimization.filter_with_random_uniform_fusion = True options.experimental_optimization.hoist_random_uniform = True options.experimental_optimization.map_and_batch_fusion = True options.experimental_optimization.map_and_filter_fusion = True options.experimental_optimization.map_parallelization = True options.experimental_optimization.map_fusion = True options.experimental_optimization.noop_elimination = True options.experimental_optimization.parallel_batch = True options.experimental_optimization.shuffle_and_repeat_fusion = True options.experimental_optimization.map_vectorization.enabled = True options.experimental_optimization.autotune_buffers = True options.experimental_deterministic = False options.experimental_stats.latency_all_edges = True options.experimental_slack = True expected_optimizations_enabled = [ "filter_fusion", "filter_with_random_uniform_fusion", "hoist_random_uniform", "map_and_batch_fusion", "map_and_filter_fusion", "map_parallelization", "map_fusion", "noop_elimination", "parallel_batch", "shuffle_and_repeat_fusion", "map_vectorization", "inject_prefetch", "make_sloppy", "latency_all_edges", "slack", ] expected_optimizations_disabled = [] expected_optimizations_default = [] graph_rewrites = options._graph_rewrites() self.assertEqual(set(graph_rewrites.enabled), set(expected_optimizations_enabled)) 
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationDisabled(self):
    """Tests the optimization settings by disabling all."""
    options = dataset_ops.Options()
    options.experimental_optimization.filter_fusion = False
    options.experimental_optimization.filter_with_random_uniform_fusion = False
    options.experimental_optimization.hoist_random_uniform = False
    options.experimental_optimization.map_and_batch_fusion = False
    options.experimental_optimization.map_and_filter_fusion = False
    options.experimental_optimization.map_parallelization = False
    options.experimental_optimization.map_fusion = False
    options.experimental_optimization.noop_elimination = False
    options.experimental_optimization.parallel_batch = False
    options.experimental_optimization.shuffle_and_repeat_fusion = False
    options.experimental_optimization.map_vectorization.enabled = False
    options.experimental_optimization.autotune = False
    options.experimental_deterministic = True
    options.experimental_stats.latency_all_edges = False
    options.experimental_slack = False

    expected_optimizations_enabled = []
    expected_optimizations_disabled = [
        "filter_fusion",
        "filter_with_random_uniform_fusion",
        "hoist_random_uniform",
        "map_and_batch_fusion",
        "map_and_filter_fusion",
        "map_parallelization",
        "map_fusion",
        "noop_elimination",
        "parallel_batch",
        "shuffle_and_repeat_fusion",
        "map_vectorization",
        "inject_prefetch",
        "make_sloppy",
        "latency_all_edges",
        "slack",
    ]
    expected_optimizations_default = []
    graph_rewrites = options._graph_rewrites()
    self.assertEqual(set(graph_rewrites.enabled),
                     set(expected_optimizations_enabled))
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

  @combinations.generate(test_base.default_test_combinations())
  def testAutotuningDefaults(self):
    options = dataset_ops.Options()

    # Check defaults
    autotune, algorithm, cpu_budget = options._autotune_settings()
    self.assertTrue(autotune)
    self.assertEqual(algorithm,
                     optimization_options._AutotuneAlgorithm.HILL_CLIMB)
    self.assertEqual(cpu_budget, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testAutotuningBufferSizes(self):
    options = dataset_ops.Options()
    options.experimental_optimization.autotune_buffers = True
    self.assertIn("inject_prefetch", options._graph_rewrites().enabled)
    autotune, algorithm, cpu_budget = options._autotune_settings()
    self.assertTrue(autotune)
    self.assertEqual(algorithm,
                     optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
    self.assertEqual(cpu_budget, 0)
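# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original test suite): the tests above
# toggle individual graph rewrites and autotuning knobs through
# `dataset_ops.Options()`. The sketch below applies the same pattern to a
# small pipeline with `with_options`; the helper name
# `_example_opt_in_rewrites` is ours, while the flags themselves are the ones
# exercised by the tests in this file.
def _example_opt_in_rewrites():
  """Returns a dataset with default rewrites off and a few rewrites opted in."""
  options = dataset_ops.Options()
  options.experimental_optimization.apply_default_optimizations = False
  options.experimental_optimization.map_and_batch_fusion = True  # Fuse map + batch.
  options.experimental_optimization.noop_elimination = True  # Drop no-op stages.
  options.experimental_optimization.autotune = True  # Tune parallelism and buffers.

  dataset = dataset_ops.Dataset.range(10).map(lambda x: x + 1).batch(2)
  return dataset.with_options(options)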