class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):

  def __init__(self, methodName):
    super().__init__(methodName)
    self._set_default_seed = False

  @combinations.generate(test_base.v1_only_combinations())
  def test_tf1_enable_tf2_behaviour(self):
    self.assertFalse(tf2.enabled())
    self.assertFalse(_pywrap_tf2.is_enabled())

    v2_compat.enable_v2_behavior()
    self.assertTrue(tf2.enabled())
    self.assertTrue(_pywrap_tf2.is_enabled())

    v2_compat.disable_v2_behavior()
    self.assertFalse(tf2.enabled())
    self.assertFalse(_pywrap_tf2.is_enabled())

  @combinations.generate(test_base.v1_only_combinations())
  def test_tf1_disable_tf2_behaviour(self):
    self.assertFalse(tf2.enabled())
    self.assertFalse(_pywrap_tf2.is_enabled())

    v2_compat.disable_v2_behavior()
    self.assertFalse(tf2.enabled())
    self.assertFalse(_pywrap_tf2.is_enabled())

    v2_compat.enable_v2_behavior()
    self.assertTrue(tf2.enabled())
    self.assertTrue(_pywrap_tf2.is_enabled())

  @combinations.generate(test_base.v2_only_combinations())
  def test_tf2_enable_tf2_behaviour(self):
    self.assertTrue(tf2.enabled())
    self.assertTrue(_pywrap_tf2.is_enabled())

    v2_compat.enable_v2_behavior()
    self.assertTrue(tf2.enabled())
    self.assertTrue(_pywrap_tf2.is_enabled())

    v2_compat.disable_v2_behavior()
    self.assertFalse(tf2.enabled())
    self.assertFalse(_pywrap_tf2.is_enabled())

  @combinations.generate(test_base.v2_only_combinations())
  def test_tf2_disable_tf2_behaviour(self):
    self.assertTrue(tf2.enabled())
    self.assertTrue(_pywrap_tf2.is_enabled())

    v2_compat.disable_v2_behavior()
    self.assertFalse(tf2.enabled())
    self.assertFalse(_pywrap_tf2.is_enabled())

    v2_compat.enable_v2_behavior()
    self.assertTrue(tf2.enabled())
    self.assertTrue(_pywrap_tf2.is_enabled())
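
# A minimal standalone sketch (not part of the suite above) of the invariant
# these four tests pin down: toggling v2 behavior must keep the Python-level
# flag (`tf2.enabled()`) and the C++-level flag (`_pywrap_tf2.is_enabled()`)
# in lockstep, regardless of which mode the process started in.
def _v2_flags_in_sync():
  return tf2.enabled() == _pywrap_tf2.is_enabled()
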
class SaveCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):

  def _build_ds(self):
    dataset = dataset_ops.Dataset.range(42)
    return dataset_ops._SaveDataset(
        dataset=dataset,
        path=self._save_dir,
        shard_func=None,
        compression=None)

  # This tests checkpointing for the _SaveDataset, which is internally
  # consumed in the save() function. The purpose of this test is to
  # thoroughly test the checkpointing functionality of the internal dataset.
  @combinations.generate(
      combinations.times(test_base.v2_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(self, self._build_ds, num_outputs=42)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPI(self):
    dataset = dataset_ops.Dataset.range(40)
    checkpoint_args = {"directory": self._checkpoint_prefix, "max_to_keep": 50}
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # By default, we checkpoint every increment. Each checkpoint writes a
    # file containing the data and a file containing the index. There is
    # also an overall checkpoint file. Thus, we expect (2 * 40) + 1 files.
    self.assertEqual(81, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPICustomCheckpointInterval(self):
    dataset = dataset_ops.Dataset.range(40)
    step_counter = variables.Variable(0, trainable=False)
    checkpoint_args = {
        "checkpoint_interval": 5,
        "step_counter": step_counter,
        "directory": self._checkpoint_prefix,
        "max_to_keep": 10,
    }
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # With a checkpoint interval of 5 over 40 elements we get 8 checkpoints,
    # so we expect (2 * 8) + 1 files.
    self.assertEqual(17, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPIIncorrectArgs(self):
    dataset = dataset_ops.Dataset.range(42)
    checkpoint_args = {
        "directory": self._checkpoint_prefix,
        "incorrect_arg": "incorrect_arg"
    }
    with self.assertRaises(TypeError):
      dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
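
# A hypothetical helper (not used by the tests above) that makes the
# checkpoint file arithmetic explicit: each checkpoint writes one data file
# and one index file, and the checkpoint manager keeps one overall state
# file, so `num_checkpoints` checkpoints yield 2 * num_checkpoints + 1 files.
def _expected_checkpoint_file_count(num_checkpoints):
  return 2 * num_checkpoints + 1

# e.g. 40 per-element checkpoints -> 81 files, and 8 checkpoints (interval 5
# over 40 elements) -> 17 files, matching the assertions above.
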
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    cluster = self.create_cluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentShuffleOrders(self):
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = self.create_cluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    output = [elem.numpy() for elem in ds]

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the
    # elements were the same, first_order and second_order computed below
    # will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    for _ in range(10):
      self.assertEqual(
          list(range(num_elements)), [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatedDataset(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      iterators.append(iter(ds))
      results.append([])

    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = next(iterators[dataset_ind]).numpy()
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    self.skipTest("Not yet implemented")
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_iterators = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    result = []
    iterators = []
    for _ in range(num_iterators):
      iterators.append(iter(ds))

    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
      for elem in it:
        result.append(elem.numpy())

    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testMaxOutstandingRequests(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertCountEqual(num_workers * list(range(num_elements)),
                          self.getDatasetOutput(ds))

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 1000

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(
        make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(
        make_ds(), cluster, job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    for _ in range(num_elements // 5):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # Iteration 1: ds1 drains the shared job, so ds2 sees nothing.
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # Iteration 2: ds2 drains the restarted job, so ds1 sees nothing.
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter1).numpy())
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testRoundRobin(self, num_workers, num_consumers):
    cluster = self.create_cluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    ds = dataset_ops.Dataset.range(10000000)
    ds = ds.repeat()
    consumers = []
    for consumer_index in range(num_consumers):
      consumers.append(
          self.make_distributed_dataset(
              ds,
              cluster,
              job_name="test",
              consumer_index=consumer_index,
              num_consumers=num_consumers))
    # Use parallel interleave to read from consumers in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
    ds = ds.interleave(
        lambda x: x,
        cycle_length=num_consumers,
        num_parallel_calls=num_consumers)
    ds = ds.take(1000)
    results = self.getDatasetOutput(ds, requires_initialization=True)

    # Check that each group of `num_consumers` results is consecutive and
    # starts at a multiple of `num_consumers`.
    for i in range(0, len(results), num_consumers):
      self.assertEqual(0, results[i] % num_consumers)
      for offset in range(1, num_consumers):
        if i + offset < len(results):
          self.assertEqual(results[i] + offset, results[i + offset])

  @combinations.generate(test_base.default_test_combinations())
  def testRoundRobinBucketizing(self):
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = self.create_cluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 10
    get_next = self.getNext(ds, requires_initialization=True)
    results = []
    for _ in range(num_rounds * num_consumers):
      results.append(self.evaluate(get_next()))

    def get_bucket(elem):
      bucket_ind = 0
      while (bucket_ind < len(bucket_boundaries) and
             elem >= bucket_boundaries[bucket_ind]):
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket.
    for i in range(0, len(results), num_consumers):
      batches = results[i:i + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(bucket_inds[0], bucket_ind)

  @combinations.generate(test_base.v1_only_combinations())
  def testRoundRobinFiniteV1(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Encountered end of sequence on a round-robin read iterator"):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(test_base.v2_only_combinations())
  def testRoundRobinFiniteV2(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Round robin reads require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.num_tasks_on_worker(), 1)
    del it
    while cluster.num_tasks_on_worker() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    # `it3` keeps the second job's task alive.
    it3 = iter(  # pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(2, cluster.num_tasks_on_worker())
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.num_tasks_on_worker() > 1:
      time.sleep(0.1)
    self.assertEqual(1, cluster.num_tasks_on_worker())

  @combinations.generate(test_base.eager_only_combinations())
  def testApplyDeterminismOption(self):
    elements = list(range(10))
    cluster = self.create_cluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(
          interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))
    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)
    cluster = self.create_cluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    next(iter(ds))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochTensorSlices(self):
    cluster = self.create_cluster(num_workers=2)
    vals = [5, 1, 2, 4]
    ds = dataset_ops.Dataset.from_tensor_slices(vals)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, vals, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochInterleave(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochParallelInterleave(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochFlatMap(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochForeverRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_elements = 20
    elements_to_read = 1000
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    it = iter(ds)
    results = {}
    for _ in range(elements_to_read):
      val = next(it).numpy()
      if val not in results:
        results[val] = 0
      results[val] += 1
    for i in range(num_elements):
      self.assertGreater(results[i], elements_to_read / num_elements / 2)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochForeverRepeatFewElements(self):
    num_workers = 5
    cluster = self.create_cluster(num_workers=num_workers)
    # Fewer elements than workers, so that some workers get zero elements on
    # the first repetition.
    num_elements = 1
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    it = iter(ds)
    for _ in range(100):
      self.assertEqual(next(it).numpy(), 0)

    # Stop all but one worker and check that we can still read.
    for i in range(num_workers - 1):
      cluster.workers[i]._stop()
    for _ in range(100):
      self.assertEqual(next(it).numpy(), 0)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochShuffleAndRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat(
        num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeFromInterleave(self):
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpoch(self):
    cluster = self.create_cluster(num_workers=2)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeEmptyAddress(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.eager_only_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertDatasetProduces(
        ds,
        list(zip(range(num_elements), range(num_elements))),
        assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetId(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdMultipleComponents(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdNotRegistered(self):
    cluster = self.create_cluster(num_workers=1)
    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000
    cluster = self.create_cluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the
    # second element slowly. Fetching the first element triggers prefetching
    # of the second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds, requires_initialization=True)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while
    # trying to garbage collect the dataset iterator.
  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterEquivalentDatasets(self):
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterDifferentDatasets(self):
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertNotEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpochOnZippedDataset(self):
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)
    ds_3 = dataset_ops.Dataset.zip((ds_1, ds_2))
    ds_3 = self.make_distributed_dataset(
        ds_3, cluster, processing_mode="distributed_epoch")
    error_regex = ("Cannot create a split provider for dataset of type "
                   "ZipDataset")
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds_3, requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpochOnDistributedDataset(self):
    cluster_1 = self.create_cluster(num_workers=1)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    numbers = [i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(numbers)
    ds = self.make_distributed_dataset(
        ds, cluster_1, processing_mode="parallel_epochs")
    ds = ds.map(lambda x: x + 1)
    ds = self.make_distributed_dataset(
        ds, cluster_2, processing_mode="distributed_epoch")
    error_regex = ("Cannot create a split provider for dataset of type "
                   "DataServiceDataset")
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testTwoLevelDistribute(self):
    cluster_1_size = 3
    cluster_1 = self.create_cluster(num_workers=cluster_1_size)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)

    it = iter(ds)
    for _ in range(num_sizes):
      element = next(it).numpy()
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(next(it).numpy(), element)
    self.assertEmpty(list(it))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeLargeGraph(self):
    cluster = self.create_cluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than the default OSS gRPC message size limit of 4 MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])
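
# A minimal sketch, outside the test class, of the registration flow the
# testFromDatasetId* and testRegister* cases above exercise: register a
# dataset with the dispatcher once, then create readers from the returned
# id. `dispatcher_target` is a placeholder for a real "grpc://host:port"
# dispatcher address.
def _register_and_read(dispatcher_target):
  ds = dataset_ops.Dataset.range(10)
  dataset_id = data_service_ops.register_dataset(dispatcher_target, ds)
  return data_service_ops.from_dataset_id(
      "parallel_epochs", dispatcher_target, dataset_id, ds.element_spec)
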
class CoordinatedReadTest(data_service_test_base.TestBase,
                          parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testBasic(self, num_workers, num_consumers):
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(100)
    results = self.getDatasetOutput(ds)
    self.checkCoordinatedReadGroups(results, num_consumers)

  @combinations.generate(test_base.default_test_combinations())
  def testConsumerRestart(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 3
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(20)
    self.getDatasetOutput(ds)
    ds2 = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds2 = ds2.take(20)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "current round has already reached"):
      self.getDatasetOutput(ds2)

  @combinations.generate(test_base.default_test_combinations())
  def testBucketizing(self):
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = data_service_test_base.TestCluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 4
    get_next = self.getNext(ds)
    results = []
    for _ in range(num_rounds * num_consumers):
      results.append(self.evaluate(get_next()))

    def get_bucket(elem):
      bucket_ind = 0
      while (bucket_ind < len(bucket_boundaries) and
             elem >= bucket_boundaries[bucket_ind]):
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket.
    for i in range(0, len(results), num_consumers):
      batches = results[i:i + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(
            bucket_inds[0], bucket_ind,
            "Batches: {}, Buckets: {}".format(batches, bucket_inds))

  @combinations.generate(test_base.v1_only_combinations())
  def testFiniteV1(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Encountered end of sequence on a round-robin read iterator"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.v2_only_combinations())
  def testFiniteV2(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Round robin reads require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds)
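
# A minimal sketch of the coordinated-read (round-robin) wiring used
# throughout these tests: every consumer attaches to the same job by name
# and identifies itself with `consumer_index` out of `num_consumers`, so
# that in each round the workers hand one element to each consumer in
# order. Assumes the `make_distributed_dataset` helper from
# data_service_test_base; the job name is hypothetical.
def _make_round_robin_consumers(test, ds, cluster, num_consumers):
  return [
      test.make_distributed_dataset(
          ds,
          cluster,
          job_name="shared_job",  # hypothetical job name
          consumer_index=i,
          num_consumers=num_consumers) for i in range(num_consumers)
  ]
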
class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _setup_files(self, inputs, linebreak='\n', compression_type=None):
    filenames = []
    for i, file_rows in enumerate(inputs):
      fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
      contents = linebreak.join(file_rows).encode('utf-8')
      if compression_type is None:
        with open(fn, 'wb') as f:
          f.write(contents)
      elif compression_type == 'GZIP':
        with gzip.GzipFile(fn, 'wb') as f:
          f.write(contents)
      elif compression_type == 'ZLIB':
        contents = zlib.compress(contents)
        with open(fn, 'wb') as f:
          f.write(contents)
      else:
        raise ValueError(
            'Unsupported compression_type: %s' % compression_type)
      filenames.append(fn)
    return filenames

  def _make_test_datasets(self, inputs, **kwargs):
    # Test by comparing its output to what we could get with map->decode_csv.
    filenames = self._setup_files(inputs)
    dataset_expected = core_readers.TextLineDataset(filenames)
    dataset_expected = dataset_expected.map(
        lambda l: parsing_ops.decode_csv(l, **kwargs))
    dataset_actual = readers.CsvDataset(filenames, **kwargs)
    return (dataset_actual, dataset_expected)

  def _test_by_comparison(self, inputs, **kwargs):
    """Checks that CsvDataset is equivalent to TextLineDataset->map(decode_csv)."""
    dataset_actual, dataset_expected = self._make_test_datasets(
        inputs, **kwargs)
    self.assertDatasetsEqual(dataset_actual, dataset_expected)

  def _test_dataset(self,
                    inputs,
                    expected_output=None,
                    expected_err_re=None,
                    linebreak='\n',
                    compression_type=None,  # Used for both setup and parsing.
                    **kwargs):
    """Checks that elements produced by CsvDataset match expected output."""
    filenames = self._setup_files(inputs, linebreak, compression_type)
    kwargs['compression_type'] = compression_type
    if expected_err_re is not None:
      # Verify that OpError is produced as expected.
      with self.assertRaisesOpError(expected_err_re):
        dataset = readers.CsvDataset(filenames, **kwargs)
        self.getDatasetOutput(dataset)
    else:
      dataset = readers.CsvDataset(filenames, **kwargs)
      # Convert str to bytes, because in py3 TF strings are bytestrings.
      expected_output = [
          tuple(
              v.encode('utf-8') if isinstance(v, str) else v for v in op)
          for op in expected_output
      ]
      self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_requiredFields(self):
    record_defaults = [[]] * 4
    inputs = [['1,2,3,4']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_int(self):
    record_defaults = [[0]] * 4
    inputs = [['1,2,3,4', '5,6,7,8']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_float(self):
    record_defaults = [[0.0]] * 4
    inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_string(self):
    record_defaults = [['']] * 4
    inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withEmptyFields(self):
    record_defaults = [[0]] * 4
    inputs = [[',,,', '1,1,1,', ',2,2,2']]
    self._test_dataset(
        inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errWithUnquotedQuotes(self):
    record_defaults = [['']] * 3
    inputs = [['1,2"3,4']]
    self._test_dataset(
        inputs,
        expected_err_re='Unquoted fields cannot have quotes inside',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errWithUnescapedQuotes(self):
    record_defaults = [['']] * 3
    inputs = [['"a"b","c","d"']]
    self._test_dataset(
        inputs,
        expected_err_re=(
            'Quote inside a string has to be escaped by another quote'),
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
    record_defaults = [['']] * 3
    inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
    filenames = self._setup_files(inputs)
    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
    dataset = dataset.apply(error_ops.ignore_errors())
    self.assertDatasetProduces(dataset, [(b'e', b'f', b'g')])

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
    record_defaults = [['']] * 3
    inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
    filenames = self._setup_files(inputs)
    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
    dataset = dataset.apply(error_ops.ignore_errors())
    self.assertDatasetProduces(dataset, [(b'e', b'f', b'g')])

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
    record_defaults = [['']] * 3
    inputs = [['1,2"3,4']]
    self._test_by_comparison(
        inputs, record_defaults=record_defaults, use_quote_delim=False)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_mixedTypes(self):
    record_defaults = [
        constant_op.constant([], dtype=dtypes.int32),
        constant_op.constant([], dtype=dtypes.float32),
        constant_op.constant([], dtype=dtypes.string),
        constant_op.constant([], dtype=dtypes.float64)
    ]
    inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withUseQuoteDelimFalse(self):
    record_defaults = [['']] * 4
    inputs = [['1,2,"3,4"', '"5,6",7,8']]
    self._test_by_comparison(
        inputs, record_defaults=record_defaults, use_quote_delim=False)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withFieldDelim(self):
    record_defaults = [[0]] * 4
    inputs = [['1:2:3:4', '5:6:7:8']]
    self._test_by_comparison(
        inputs, record_defaults=record_defaults, field_delim=':')

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withNaValue(self):
    record_defaults = [[0]] * 4
    inputs = [['1,NA,3,4', 'NA,6,7,8']]
    self._test_by_comparison(
        inputs, record_defaults=record_defaults, na_value='NA')

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withSelectCols(self):
    record_defaults = [['']] * 2
    inputs = [['1,2,3,4', '"5","6","7","8"']]
    self._test_by_comparison(
        inputs, record_defaults=record_defaults, select_cols=[1, 2])

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withSelectColsTooHigh(self):
    record_defaults = [[0]] * 2
    inputs = [['1,2,3,4', '5,6,7,8']]
    self._test_dataset(
        inputs,
        expected_err_re='Expect 2 fields but have 1 in record',
        record_defaults=record_defaults,
        select_cols=[3, 4])

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withOneCol(self):
    record_defaults = [['NA']]
    inputs = [['0', '', '2']]
    self._test_dataset(
        inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withMultipleFiles(self):
    record_defaults = [[0]] * 4
    inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withLeadingAndTrailingSpaces(self):
    record_defaults = [[0.0]] * 4
    inputs = [['0, 1, 2, 3']]
    expected = [[0.0, 1.0, 2.0, 3.0]]
    self._test_dataset(inputs, expected, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errorWithMissingDefault(self):
    record_defaults = [[]] * 2
    inputs = [['0,']]
    self._test_dataset(
        inputs,
        expected_err_re='Field 1 is required but missing in record!',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errorWithFewerDefaultsThanFields(self):
    record_defaults = [[0.0]] * 2
    inputs = [['0,1,2,3']]
    self._test_dataset(
        inputs,
        expected_err_re='Expect 2 fields but have more in record',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errorWithMoreDefaultsThanFields(self):
    record_defaults = [[0.0]] * 5
    inputs = [['0,1,2,3']]
    self._test_dataset(
        inputs,
        expected_err_re='Expect 5 fields but have 4 in record',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withHeader(self):
    record_defaults = [[0]] * 2
    inputs = [['col1,col2', '1,2']]
    expected = [[1, 2]]
    self._test_dataset(
        inputs,
        expected,
        record_defaults=record_defaults,
        header=True,
    )

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withHeaderAndNoRecords(self):
    record_defaults = [[0]] * 2
    inputs = [['col1,col2']]
    expected = []
    self._test_dataset(
        inputs,
        expected,
        record_defaults=record_defaults,
        header=True,
    )

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errorWithHeaderEmptyFile(self):
    record_defaults = [[0]] * 2
    inputs = [[]]
    expected_err_re = "Can't read header of file"
    self._test_dataset(
        inputs,
        expected_err_re=expected_err_re,
        record_defaults=record_defaults,
        header=True,
    )

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withEmptyFile(self):
    record_defaults = [['']] * 2
    inputs = [['']]  # Empty file.
    self._test_dataset(
        inputs, expected_output=[], record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errorWithEmptyRecord(self):
    record_defaults = [['']] * 2
    inputs = [['', '1,2']]  # First record is empty.
    self._test_dataset(
        inputs,
        expected_err_re='Expect 2 fields but have 1 in record',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withChainedOps(self):
    # Testing that one dataset can create multiple iterators fine.
    # `repeat` creates multiple iterators from the same C++ Dataset.
    record_defaults = [[0]] * 4
    inputs = [['1,,3,4', '5,6,,8']]
    ds_actual, ds_expected = self._make_test_datasets(
        inputs, record_defaults=record_defaults)
    self.assertDatasetsEqual(
        ds_actual.repeat(5).prefetch(1),
        ds_expected.repeat(5).prefetch(1))

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withTypeDefaults(self):
    # Testing using dtypes as record_defaults for required fields.
    record_defaults = [dtypes.float32, [0.0]]
    inputs = [['1.0,2.0', '3.0,4.0']]
    self._test_dataset(
        inputs,
        [[1.0, 2.0], [3.0, 4.0]],
        record_defaults=record_defaults,
    )

  @combinations.generate(test_base.default_test_combinations())
  def testMakeCsvDataset_fieldOrder(self):
    data = [[
        '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
        '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
    ]]
    file_path = self._setup_files(data)
    ds = readers.make_csv_dataset(
        file_path, batch_size=1, shuffle=False, num_epochs=1)
    nxt = self.getNext(ds)
    result = list(self.evaluate(nxt()).values())
    self.assertEqual(result, sorted(result))

  ## The following tests exercise parsing logic for quoted fields.

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withQuoted(self):
    record_defaults = [['']] * 4
    inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withOneColAndQuotes(self):
    record_defaults = [['']]
    inputs = [['"0"', '"1"', '"2"']]
    self._test_dataset(
        inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withNewLine(self):
    # In this case, we expect it to behave differently from
    # TextLineDataset->map(decode_csv) since that flow has bugs.
    record_defaults = [['']] * 4
    inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
    expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
    self._test_dataset(inputs, expected, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withNewLineInUnselectedCol(self):
    record_defaults = [['']]
    inputs = [['1,"2\n3",4', '5,6,7']]
    self._test_dataset(
        inputs,
        expected_output=[['1'], ['5']],
        record_defaults=record_defaults,
        select_cols=[0])

  @combinations.generate(test_base.v2_only_combinations())
  def testCsvDataset_withExcludeCol(self):
    record_defaults = [['']]
    inputs = [['1,2,3', '5,6,7']]
    self._test_dataset(
        inputs,
        expected_output=[['1'], ['5']],
        record_defaults=record_defaults,
        exclude_cols=[1, 2])

  @combinations.generate(test_base.v2_only_combinations())
  def testCsvDataset_withSelectAndExcludeCol(self):
    record_defaults = [['']]
    inputs = [['1,2,3', '5,6,7']]
    self._test_dataset(
        inputs,
        expected_err_re='Either select_cols or exclude_cols should be empty',
        record_defaults=record_defaults,
        select_cols=[0],
        exclude_cols=[1, 2])

  @combinations.generate(test_base.v2_only_combinations())
  def testCsvDataset_withExcludeColAndRecordDefaultsTooLow(self):
    record_defaults = [['']]
    inputs = [['1,2,3', '5,6,7']]
    self._test_dataset(
        inputs,
        expected_err_re='Expect 1 fields but have more in record',
        record_defaults=record_defaults,
        exclude_cols=[0])

  @combinations.generate(test_base.v2_only_combinations())
  def testCsvDataset_withExcludeColAndRecordDefaultsTooHigh(self):
    record_defaults = [['']] * 3
    inputs = [['1,2,3', '5,6,7']]
    self._test_dataset(
        inputs,
        expected_err_re='Expect 3 fields but have 2 in record',
        record_defaults=record_defaults,
        exclude_cols=[0])

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withMultipleNewLines(self):
    # In this case, we expect it to behave differently from
    # TextLineDataset->map(decode_csv) since that flow has bugs.
    record_defaults = [['']] * 4
    inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
    expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
    self._test_dataset(inputs, expected, record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_errorWithTerminateMidRecord(self):
    record_defaults = [['']] * 4
    inputs = [['a,b,c,"a']]
    self._test_dataset(
        inputs,
        expected_err_re=(
            'Reached end of file without closing quoted field in record'),
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withEscapedQuotes(self):
    record_defaults = [['']] * 4
    inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
    self._test_by_comparison(inputs, record_defaults=record_defaults)

  ## Testing that parsing works with all buffer sizes, quoted/unquoted
  ## fields, and different types of line breaks.

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withInvalidBufferSize(self):
    record_defaults = [['']] * 4
    inputs = [['a,b,c,d']]
    self._test_dataset(
        inputs,
        expected_err_re='buffer_size should be positive',
        record_defaults=record_defaults,
        buffer_size=0)

  def _test_dataset_on_buffer_sizes(self,
                                    inputs,
                                    expected,
                                    linebreak,
                                    record_defaults,
                                    compression_type=None,
                                    num_sizes_to_test=20):
    # Test reading with a range of buffer sizes that should all work.
    for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
      self._test_dataset(
          inputs,
          expected,
          linebreak=linebreak,
          compression_type=compression_type,
          record_defaults=record_defaults,
          buffer_size=i)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withLF(self):
    record_defaults = [['NA']] * 3
    inputs = [['abc,def,ghi', '0,1,2', ',,']]
    expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs, expected, linebreak='\n', record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withCR(self):
    # Test that when the line separator is '\r', parsing works with all
    # buffer sizes.
    record_defaults = [['NA']] * 3
    inputs = [['abc,def,ghi', '0,1,2', ',,']]
    expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs, expected, linebreak='\r', record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withCRLF(self):
    # Test that when the line separator is '\r\n', parsing works with all
    # buffer sizes.
    record_defaults = [['NA']] * 3
    inputs = [['abc,def,ghi', '0,1,2', ',,']]
    expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs, expected, linebreak='\r\n', record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withBufferSizeAndQuoted(self):
    record_defaults = [['NA']] * 3
    inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
    expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs, expected, linebreak='\n', record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withCRAndQuoted(self):
    # Test that when the line separator is '\r', parsing works with all
    # buffer sizes.
    record_defaults = [['NA']] * 3
    inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
    expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs, expected, linebreak='\r', record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withCRLFAndQuoted(self):
    # Test that when the line separator is '\r\n', parsing works with all
    # buffer sizes.
    record_defaults = [['NA']] * 3
    inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
    expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs, expected, linebreak='\r\n', record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withGzipCompressionType(self):
    record_defaults = [['NA']] * 3
    inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
    expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs,
        expected,
        linebreak='\r\n',
        compression_type='GZIP',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withZlibCompressionType(self):
    record_defaults = [['NA']] * 3
    inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
    expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                ['NA', 'NA', 'NA']]
    self._test_dataset_on_buffer_sizes(
        inputs,
        expected,
        linebreak='\r\n',
        compression_type='ZLIB',
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_withScalarDefaults(self):
    record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
    inputs = [[',,,', '1,1,1,', ',2,2,2']]
    self._test_dataset(
        inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
        record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_with2DDefaults(self):
    record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
    inputs = [[',,,', '1,1,1,', ',2,2,2']]

    if context.executing_eagerly():
      err_spec = errors.InvalidArgumentError, (
          'Each record default should be at most rank 1')
    else:
      err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
    with self.assertRaisesWithPredicateMatch(*err_spec):
      self._test_dataset(
          inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
          record_defaults=record_defaults)

  @combinations.generate(test_base.default_test_combinations())
  def testCsvDataset_immutableParams(self):
    inputs = [['a,b,c', '1,2,3', '4,5,6']]
    filenames = self._setup_files(inputs)
    select_cols = ['a', 'c']
    _ = readers.make_csv_dataset(
        filenames, batch_size=1, select_columns=select_cols)
    # `make_csv_dataset` should not mutate the list passed as
    # `select_columns`.
    self.assertAllEqual(select_cols, ['a', 'c'])
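
# A minimal usage sketch of the API under test. `record_defaults` both fixes
# each column's dtype and supplies the value substituted for empty fields;
# the file path here is hypothetical.
def _example_csv_dataset():
  return readers.CsvDataset(
      ['/tmp/example.csv'],  # hypothetical input file
      record_defaults=[[0], [0.0], ['NA']],  # int32, float32, string columns
      field_delim=',',
      header=True)
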