class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
  """Checks that enable/disable_v2_behavior toggle TF2 mode consistently.

  Each test verifies that the Python-level flag (`tf2.enabled()`) and the
  C++-level flag (`_pywrap_tf2.is_enabled()`) stay in agreement as v2
  behavior is enabled and disabled, starting from both TF1 and TF2 modes.
  """

  def __init__(self, methodName):
    super().__init__(methodName)
    self._set_default_seed = False

  def _check_tf2_state(self, expect_enabled):
    # The Python flag and the pywrap (C++) flag must always agree.
    if expect_enabled:
      self.assertTrue(tf2.enabled())
      self.assertTrue(_pywrap_tf2.is_enabled())
    else:
      self.assertFalse(tf2.enabled())
      self.assertFalse(_pywrap_tf2.is_enabled())

  @combinations.generate(test_base.v1_only_combinations())
  def test_tf1_enable_tf2_behaviour(self):
    self._check_tf2_state(False)

    v2_compat.enable_v2_behavior()
    self._check_tf2_state(True)

    v2_compat.disable_v2_behavior()
    self._check_tf2_state(False)

  @combinations.generate(test_base.v1_only_combinations())
  def test_tf1_disable_tf2_behaviour(self):
    self._check_tf2_state(False)

    v2_compat.disable_v2_behavior()
    self._check_tf2_state(False)

    v2_compat.enable_v2_behavior()
    self._check_tf2_state(True)

  @combinations.generate(test_base.v2_only_combinations())
  def test_tf2_enable_tf2_behaviour(self):
    self._check_tf2_state(True)

    v2_compat.enable_v2_behavior()
    self._check_tf2_state(True)

    v2_compat.disable_v2_behavior()
    self._check_tf2_state(False)

  @combinations.generate(test_base.v2_only_combinations())
  def test_tf2_disable_tf2_behaviour(self):
    self._check_tf2_state(True)

    v2_compat.disable_v2_behavior()
    self._check_tf2_state(False)

    v2_compat.enable_v2_behavior()
    self._check_tf2_state(True)
class FromSparseTensorSlicesCheckpointTest(
    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
  """Checkpoint/restore test for `Dataset.from_sparse_tensor_slices`."""

  def _build_sparse_tensor_slice_dataset(self, slices):
    """Builds a sparse-tensor-slices dataset from a ragged list of lists.

    Args:
      slices: list of lists of floats; each inner list becomes one row of the
        sparse tensor (empty rows are allowed).

    Returns:
      A `Dataset` yielding one sparse slice per row of `slices`.
    """
    # Flatten `slices` into COO components (row/col indices plus values).
    # pylint: disable=g-complex-comprehension
    indices = np.array(
        [[row, col]
         for row, s in enumerate(slices)
         for col in range(len(s))],
        dtype=np.int64)
    values = np.array([val for s in slices for val in s], dtype=np.float64)
    # pylint: enable=g-complex-comprehension
    dense_shape = np.array(
        [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
    sparse_components = sparse_tensor.SparseTensor(indices, values,
                                                   dense_shape)
    return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)

  @combinations.generate(
      combinations.times(test_base.v1_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
    verify_fn(
        self,
        lambda: self._build_sparse_tensor_slice_dataset(slices),
        num_outputs=9,
        sparse_tensors=True)
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
  """End-to-end tests for distributing datasets via the tf.data service."""

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    ds = self.make_distributed_range_dataset(10, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    cluster = self.create_cluster(num_workers=1)
    element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    ds = dataset_ops.Dataset.from_tensors(element)
    ds = self.make_distributed_dataset(ds, cluster)
    results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
    self.assertAllEqual(results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ds = ds.map(math_ops.range)
    ds = ds.apply(batching.dense_to_ragged_batch(2))
    ds = self.make_distributed_dataset(ds, cluster)
    results = [elem.to_tensor() for elem in ds]
    self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
    self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
    self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentShuffleOrders(self):
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = self.create_cluster(num_workers=2)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.shuffle(num_elements)
    ds = self.make_distributed_dataset(ds, cluster)
    output = [elem.numpy() for elem in ds]

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. If the orders of the elements
    # were the same, first_order and second_order computed below will be equal.
    first_order = {}
    second_order = {}
    for element in output:
      if element in first_order:
        second_order[element] = len(second_order)
      else:
        first_order[element] = len(first_order)
    self.assertNotEqual(first_order, second_order)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    for _ in range(10):
      self.assertEqual(
          list(range(num_elements)), [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatedDataset(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    ds = ds.repeat(num_repetitions)
    self.assertDatasetProduces(
        ds, expected_output=num_repetitions * list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      iterators.append(iter(ds))
      results.append([])

    for _ in range(num_elements):
      for dataset_ind in range(num_datasets):
        result = next(iterators[dataset_ind]).numpy()
        results[dataset_ind].append(result)
    for result in results:
      self.assertEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    self.skipTest("Not yet implemented")
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_iterators = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    result = []
    iterators = []
    for _ in range(num_iterators):
      iterators.append(iter(ds))

    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
      for elem in it:
        result.append(elem.numpy())

    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertCountEqual(num_workers * list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testMaxOutstandingRequests(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    self.assertCountEqual(num_workers * list(range(num_elements)),
                          self.getDatasetOutput(ds))

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 1000

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(
        make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(
        make_ds(), cluster, job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    for _ in range(num_elements // 5):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # iteration 1
    self.assertDatasetProduces(ds1, list(range(num_elements)))
    self.assertDatasetProduces(ds2, [])
    # iteration 2
    self.assertDatasetProduces(ds2, list(range(num_elements)))
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter1).numpy())
    for _ in range((num_elements * num_repetitions) // 5):
      results.append(next(iter2).numpy())
    for elem in iter1:
      results.append(elem.numpy())
    for elem in iter2:
      results.append(elem.numpy())
    self.assertCountEqual(num_repetitions * list(range(num_elements)),
                          results)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testRoundRobin(self, num_workers, num_consumers):
    cluster = self.create_cluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    ds = dataset_ops.Dataset.range(10000000)
    ds = ds.repeat()
    consumers = []
    for consumer_index in range(num_consumers):
      consumers.append(
          self.make_distributed_dataset(
              ds,
              cluster,
              job_name="test",
              consumer_index=consumer_index,
              num_consumers=num_consumers))
    # Use parallel interleave to read from consumers in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
    ds = ds.interleave(
        lambda x: x,
        cycle_length=num_consumers,
        num_parallel_calls=num_consumers)
    ds = ds.take(1000)
    results = self.getDatasetOutput(ds, requires_initialization=True)

    for i in range(0, len(results), num_consumers):
      self.assertEqual(0, results[i] % num_consumers)
      # Check that each group of `num_consumers` results are consecutive.
      for offset in range(1, num_consumers):
        if i + offset < len(results):
          self.assertEqual(results[i] + offset, results[i + offset])

  @combinations.generate(test_base.default_test_combinations())
  def testRoundRobinBucketizing(self):
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = self.create_cluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 10
    get_next = self.getNext(ds, requires_initialization=True)
    results = []
    for _ in range(num_rounds * num_consumers):
      results.append(self.evaluate(get_next()))

    def get_bucket(elem):
      bucket_ind = 0
      while (bucket_ind < len(bucket_boundaries) and
             elem >= bucket_boundaries[bucket_ind]):
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket. Note: `i` already advances by `num_consumers`, so the group is
    # results[i:i + num_consumers]; multiplying `i` by `num_consumers` again
    # (as the previous version did) produced empty, unchecked slices for
    # every group after the first.
    for i in range(0, len(results), num_consumers):
      batches = results[i:i + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(bucket_inds[0], bucket_ind)

  @combinations.generate(test_base.v1_only_combinations())
  def testRoundRobinFiniteV1(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Encountered end of sequence on a "
        "round-robin read iterator"):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(test_base.v2_only_combinations())
  def testRoundRobinFiniteV2(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Round robin reads "
        "require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.num_tasks_on_worker(), 1)
    del it
    while cluster.num_tasks_on_worker() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(2, cluster.num_tasks_on_worker())
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.num_tasks_on_worker() > 1:
      time.sleep(0.1)
    self.assertEqual(1, cluster.num_tasks_on_worker())

  @combinations.generate(test_base.eager_only_combinations())
  def testApplyDeterminismOption(self):
    elements = list(range(10))
    cluster = self.create_cluster(num_workers=1)

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(
          interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Runs a stateful (random) dataset through the tf.data service."""
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))
    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)
    cluster = self.create_cluster(num_workers=3)
    ds = self.make_distributed_dataset(ds, cluster)
    next(iter(ds))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochTensorSlices(self):
    cluster = self.create_cluster(num_workers=2)
    vals = [5, 1, 2, 4]
    ds = dataset_ops.Dataset.from_tensor_slices(vals)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, vals, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochInterleave(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochParallelInterleave(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochFlatMap(self):
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)
    ds = ds.flat_map(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(ds, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochForeverRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_elements = 20
    elements_to_read = 1000
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    it = iter(ds)
    results = {}
    for _ in range(elements_to_read):
      val = next(it).numpy()
      if val not in results:
        results[val] = 0
      results[val] += 1
    for i in range(num_elements):
      self.assertGreater(results[i], elements_to_read / num_elements / 2)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochForeverRepeatFewElements(self):
    num_workers = 5
    cluster = self.create_cluster(num_workers=num_workers)
    # Less than the number of workers, so that some workers get zero elements on
    # the first repetition.
    num_elements = 1
    ds = dataset_ops.Dataset.range(num_elements).repeat()
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    it = iter(ds)
    for _ in range(100):
      self.assertEqual(next(it).numpy(), 0)

    # Stop all but one worker and check that we can still read.
    for i in range(num_workers - 1):
      cluster.workers[i]._stop()
    for _ in range(100):
      self.assertEqual(next(it).numpy(), 0)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochShuffleAndRepeat(self):
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    ds = dataset_ops.Dataset.range(num_elements).shuffle(
        num_elements).repeat(num_repeats)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, num_repeats * list(range(num_elements)), assert_items_equal=True)

  # NOTE: added the combinations decorator that every other test method in
  # this class has; it was missing, so this test ran outside the
  # combinations framework.
  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeFromInterleave(self):
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpoch(self):
    cluster = self.create_cluster(num_workers=2)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        ds, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeEmptyAddress(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    ds = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      ds = ds.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.eager_only_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    self.assertDatasetProduces(
        ds,
        list(zip(range(num_elements), range(num_elements))),
        assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    ds1 = dataset_ops.Dataset.range(num_elements)
    ds1 = self.make_distributed_dataset(
        ds1, cluster, processing_mode="distributed_epoch",
        job_name="job_name")
    ds2 = dataset_ops.Dataset.range(num_elements)
    ds2 = self.make_distributed_dataset(
        ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetId(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdMultipleComponents(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdNotRegistered(self):
    cluster = self.create_cluster(num_workers=1)
    dataset_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(from_dataset_id_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000
    cluster = self.create_cluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the second
    # element slowly. Fetching the first element triggers prefetching of the
    # second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds, requires_initialization=True)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while trying
    # to garbage collect the dataset iterator.

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterEquivalentDatasets(self):
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterDifferentDatasets(self):
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(20)
    cluster = self.create_cluster(num_workers=1)
    id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
    id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
    self.assertNotEqual(id_1.numpy(), id_2.numpy())

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpochOnZippedDataset(self):
    ds_1 = dataset_ops.Dataset.range(10)
    ds_2 = dataset_ops.Dataset.range(10)
    cluster = self.create_cluster(num_workers=1)

    ds_3 = dataset_ops.Dataset.zip((ds_1, ds_2))
    ds_3 = self.make_distributed_dataset(
        ds_3, cluster, processing_mode="distributed_epoch")

    error_regex = ("Cannot create a split provider for dataset "
                   "of type ZipDataset")
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds_3, requires_initialization=True)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributedEpochOnDistributedDataset(self):
    cluster_1 = self.create_cluster(num_workers=1)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    numbers = [1 * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(numbers)
    ds = self.make_distributed_dataset(
        ds, cluster_1, processing_mode="parallel_epochs")
    ds = ds.map(lambda x: x + 1)
    ds = self.make_distributed_dataset(
        ds, cluster_2, processing_mode="distributed_epoch")

    error_regex = ("Cannot create a split provider for dataset "
                   "of type DataServiceDataset")
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds, requires_initialization=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testTwoLevelDistribute(self):
    cluster_1_size = 3
    cluster_1 = self.create_cluster(num_workers=cluster_1_size)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)

    it = iter(ds)
    for _ in range(num_sizes):
      element = next(it).numpy()
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(next(it).numpy(), element)
    self.assertEmpty(list(it))

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations()))
  def testDistributeLargeGraph(self):
    cluster = self.create_cluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    ds = dataset_ops.Dataset.from_tensors(tensor)
    ds = self.make_distributed_dataset(ds, cluster)
    self.assertDatasetProduces(ds, [tensor])
class CoordinatedReadTest(data_service_test_base.TestBase,
                          parameterized.TestCase):
  """Tests for coordinated (round-robin) reads from the tf.data service."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testBasic(self, num_workers, num_consumers):
    """Round-robin consumers produce well-formed per-round groups."""
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(100)
    results = self.getDatasetOutput(ds)
    self.checkCoordinatedReadGroups(results, num_consumers)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testConsumerRestart(self):
    """Restarting consumers after rounds have advanced is rejected."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 3
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(20)
    self.getDatasetOutput(ds)
    # A fresh consumer set would restart from round 0, which the server
    # refuses because the job's current round has already moved past it.
    ds2 = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds2 = ds2.take(20)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "current round has already reached"):
      self.getDatasetOutput(ds2)

  @combinations.generate(test_base.default_test_combinations())
  def testBucketizing(self):
    """Each global step gives every consumer a batch from the same bucket."""
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = data_service_test_base.TestCluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    # Regroup batches so that `num_consumers` same-bucket batches are
    # emitted consecutively, one full round per window.
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 4
    get_next = self.getNext(ds)
    results = []
    for _ in range(num_rounds * num_consumers):
      results.append(self.evaluate(get_next()))

    def get_bucket(elem):
      # Index of the boundary bucket that `elem` falls into.
      bucket_ind = 0
      while bucket_ind < len(
          bucket_boundaries) and elem >= bucket_boundaries[bucket_ind]:
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket. `i` already advances in steps of `num_consumers`, so the slice
    # is simply [i, i + num_consumers); the previous code multiplied `i` by
    # `num_consumers` again, which produced empty slices for every round
    # after the first and made those assertions vacuous.
    for i in range(0, len(results), num_consumers):
      batches = results[i:i + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(
            bucket_inds[0], bucket_ind,
            "Batches: {}, Buckets: {}".format(batches, bucket_inds))

  @combinations.generate(test_base.v1_only_combinations())
  def testFiniteV1(self):
    """V1: a finite dataset exhausts mid-round and raises at read time."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Encountered end of sequence on a "
        "round-robin read iterator"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.v2_only_combinations())
  def testFiniteV2(self):
    """V2: finite cardinality is rejected up front for round-robin reads."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Round robin reads "
        "require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds)
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
                              parameterized.TestCase):
  """Tests for MultiDeviceIterator distributing one dataset across devices."""

  @combinations.generate(
      combinations.times(test_base.v1_only_combinations(),
                         combinations.combine(num_inits=[0, 1, 42])))
  def testInitOnly(self, num_inits):
    """Running the initializer zero, one, or many times is safe."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      for _ in range(num_inits):
        self.evaluate(multi_device_iterator.initializer)

  @combinations.generate(test_base.v1_only_combinations())
  def testBasic(self):
    """Two CPU devices receive alternating elements until exhaustion."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      # Device 1 gets the even elements, device 2 the odd ones.
      for i in range(0, 10, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.v1_only_combinations())
  def testOneOnSameDevice(self):
    """A target device may coincide with the dataset's own device."""
    with ops.device("/cpu:0"):
      dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:0", "/cpu:1"])
    config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 10, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.v1_only_combinations())
  def testRepeatDevices(self):
    """The same device may appear multiple times in the device list."""
    with ops.device("/cpu:0"):
      dataset = dataset_ops.Dataset.range(20)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      # Four logical consumers, so elements are dealt out four at a time.
      for i in range(0, 20, 4):
        elements = multi_device_iterator.get_next()
        elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
        self.assertEqual(i + 2, self.evaluate(elem_on_3))
        self.assertEqual(i + 3, self.evaluate(elem_on_4))
      with self.assertRaises(errors.OutOfRangeError):
        elements = multi_device_iterator.get_next()
        elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)
        self.evaluate(elem_on_3)
        self.evaluate(elem_on_4)

  @combinations.generate(test_base.v1_only_combinations())
  def testNotFullyDivisible(self):
    """With 9 elements on 2 devices, the leftover element goes to device 1."""
    dataset = dataset_ops.Dataset.range(9)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 8, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      # Only element 8 remains; fetch it from the first device alone.
      elem_on_1 = multi_device_iterator.get_next("/cpu:1")
      self.assertEqual(8, self.evaluate(elem_on_1))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.v1_only_combinations())
  def testGetNextAsOptional(self):
    """get_next_as_optional() reports has_value=False past the end.

    Graph-mode only: the optional's has_value/get_value tensors are built
    once and re-run across sess.run calls.
    """
    if context.executing_eagerly():
      return

    dataset = dataset_ops.Dataset.range(9)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
    elem_on_1_has_value_t = elem_on_1.has_value()
    elem_on_1_t = elem_on_1.get_value()
    elem_on_2_has_value_t = elem_on_2.has_value()
    elem_on_2_t = elem_on_2.get_value()

    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config) as sess:
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 8, 2):
        elem_on_1_has_value, elem_on_1_value = sess.run(
            [elem_on_1_has_value_t, elem_on_1_t])
        self.assertTrue(elem_on_1_has_value)
        self.assertEqual(i, elem_on_1_value)
        elem_on_2_has_value, elem_on_2_value = sess.run(
            [elem_on_2_has_value_t, elem_on_2_t])
        self.assertTrue(elem_on_2_has_value)
        self.assertEqual(i + 1, elem_on_2_value)
      # The odd element out (8) still arrives on device 1.
      elem_on_1_has_value, elem_on_1_value = sess.run(
          [elem_on_1_has_value_t, elem_on_1_t])
      self.assertTrue(elem_on_1_has_value)
      self.assertEqual(8, elem_on_1_value)
      # Past the end: optionals are empty and get_value() raises.
      self.assertFalse(self.evaluate(elem_on_1_has_value_t))
      self.assertFalse(self.evaluate(elem_on_2_has_value_t))
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(elem_on_1_t)
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(elem_on_2_t)

  @combinations.generate(test_base.v1_only_combinations())
  def testUneven(self):
    """One device can be drained ahead of the other (max_buffer_size=4)."""
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      # Read all of device 1's elements before touching device 2.
      for i in range(0, 10, 2):
        elem_on_1 = multi_device_iterator.get_next("/cpu:1")
        self.assertEqual(i, self.evaluate(elem_on_1))
      for i in range(0, 10, 2):
        elem_on_2 = multi_device_iterator.get_next("/cpu:2")
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.v1_only_combinations())
  def testMultipleInitializationsGraph(self):
    """Re-running the initializer restarts the pipeline with fresh feeds."""
    if context.executing_eagerly():
      return

    with ops.device("/cpu:0"):
      epoch = array_ops.placeholder(dtypes.int64, shape=[])
      dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
      dataset2 = dataset_ops.Dataset.range(1000)
      dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
    elem_on_1, elem_on_2 = multi_device_iterator.get_next()
    init_op = multi_device_iterator.initializer

    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    pool = config.session_inter_op_thread_pool.add()
    pool.num_threads = 2
    with session.Session(config=config) as sess:
      for i in range(1000):
        sess.run(init_op, feed_dict={epoch: i})
        # After each re-init, the zip restarts: first pair is (i, 0)/(i, 1).
        self.assertEqual([(i, 0), (i, 1)],
                         self.evaluate([elem_on_1, elem_on_2]))

  @combinations.generate(test_base.v1_only_combinations())
  def testMultipleInitializationsEager(self):
    """Each newly built eager iterator starts the dataset from the top."""
    if not context.executing_eagerly():
      return

    with ops.device("/cpu:0"):
      dataset1 = dataset_ops.Dataset.range(1000)
      dataset2 = dataset_ops.Dataset.range(1000)
      dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

    for _ in range(5):
      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
          dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
      self.assertEqual([(0, 0), (1, 1)],
                       self.evaluate([elem_on_1, elem_on_2]))

  @combinations.generate(test_base.v1_only_combinations())
  def testBasicGpu(self):
    """Same alternating pattern as testBasic, with a GPU as device 2."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/gpu:0"])
    config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 10, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.v1_only_combinations())
  def testUnevenGpu(self):
    """Same as testUneven, with a GPU as the second device."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
    config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 10, 2):
        elem_on_1 = multi_device_iterator.get_next("/cpu:1")
        self.assertEqual(i, self.evaluate(elem_on_1))
      for i in range(0, 10, 2):
        elem_on_2 = multi_device_iterator.get_next("/gpu:0")
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)

  @combinations.generate(test_base.v1_only_combinations())
  def testGetNextAsOptionalGpu(self):
    """Same as testGetNextAsOptional, with a GPU as the second device."""
    # NOTE(review): the skip message only mentions the GPU even though the
    # condition also skips eager execution.
    if not test_util.is_gpu_available() or context.executing_eagerly():
      self.skipTest("No GPU available")

    dataset = dataset_ops.Dataset.range(9)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/gpu:0"])
    elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
    elem_on_1_has_value_t = elem_on_1.has_value()
    elem_on_1_t = elem_on_1.get_value()
    elem_on_2_has_value_t = elem_on_2.has_value()
    elem_on_2_t = elem_on_2.get_value()

    config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
    with self.test_session(config=config) as sess:
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 8, 2):
        elem_on_1_has_value, elem_on_1_value = sess.run(
            [elem_on_1_has_value_t, elem_on_1_t])
        self.assertTrue(elem_on_1_has_value)
        self.assertEqual(i, elem_on_1_value)
        elem_on_2_has_value, elem_on_2_value = sess.run(
            [elem_on_2_has_value_t, elem_on_2_t])
        self.assertTrue(elem_on_2_has_value)
        self.assertEqual(i + 1, elem_on_2_value)
      elem_on_1_has_value, elem_on_1_value = sess.run(
          [elem_on_1_has_value_t, elem_on_1_t])
      self.assertTrue(elem_on_1_has_value)
      self.assertEqual(8, elem_on_1_value)
      self.assertFalse(self.evaluate(elem_on_1_has_value_t))
      self.assertFalse(self.evaluate(elem_on_2_has_value_t))
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(elem_on_1_t)
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(elem_on_2_t)

  @combinations.generate(test_base.v1_only_combinations())
  def testOptimization(self):
    """Dataset graph optimizations apply under a multi-device iterator."""
    dataset = dataset_ops.Dataset.range(10)
    # assert_next verifies that noop_elimination removed the skip(0) and the
    # cache op is the next transformation seen at runtime.
    dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
    dataset = dataset.skip(0)  # this should be optimized away
    dataset = dataset.cache()

    options = dataset_ops.Options()
    options.experimental_optimization.noop_elimination = True
    dataset = dataset.with_options(options)

    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      for i in range(0, 10, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)
class CheckpointInputPipelineHookTest(test.TestCase, parameterized.TestCase):
  """Tests that CheckpointInputPipelineHook saves/restores iterator state."""

  @staticmethod
  def _model_fn(features, labels, mode, config):
    """Toy model_fn: bumps global_step and records the newest feature."""
    del labels, mode, config  # Unused by this toy model.
    global_step = training_util.get_or_create_global_step()
    bump_step = global_step.assign_add(1)
    newest = variables.VariableV1(
        0, name='latest_feature', dtype=dtypes.int64)
    record_feature = newest.assign(features)
    # Expose both variables so _read_vars can fetch them from a checkpoint.
    ops.add_to_collection('my_vars', global_step)
    ops.add_to_collection('my_vars', newest)
    return model_fn.EstimatorSpec(
        mode='train',
        train_op=control_flow_ops.group([bump_step, record_feature]),
        loss=constant_op.constant(2.0))

  def _read_vars(self, model_dir):
    """Returns (global_step, latest_feature) from the newest checkpoint."""
    with ops.Graph().as_default() as restore_graph:
      checkpoint_path = checkpoint_management.latest_checkpoint(model_dir)
      saver_lib.import_meta_graph(checkpoint_path + '.meta')
      restorer = saver_lib.Saver()
      with self.session(graph=restore_graph) as sess:
        restorer.restore(sess, checkpoint_path)
        return sess.run(ops.get_collection('my_vars'))

  def _build_iterator_saver_hook(self, est):
    """Builds a hook that checkpoints `est`'s input pipeline."""
    return iterator_ops.CheckpointInputPipelineHook(est)

  @combinations.generate(test_base.v1_only_combinations())
  def testReturnDatasetFromInputFn(self):
    """Pipeline state survives across train() calls when the hook is used."""

    def _input_fn():
      return dataset_ops.Dataset.range(10)

    est = estimator.Estimator(model_fn=self._model_fn)
    # Each 2-step run advances global_step by 2 and resumes the dataset
    # where the previous run stopped: last features seen are 1, then 3.
    for expected in [(2, 1), (4, 3)]:
      est.train(
          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
      self.assertSequenceEqual(self._read_vars(est.model_dir), expected)

  @combinations.generate(test_base.v1_only_combinations())
  def testBuildIteratorInInputFn(self):
    """The hook also works when input_fn returns a get_next() tensor."""

    def _input_fn():
      iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
      return iterator.get_next()

    est = estimator.Estimator(model_fn=self._model_fn)
    for expected in [(2, 1), (4, 3)]:
      est.train(
          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
      self.assertSequenceEqual(self._read_vars(est.model_dir), expected)

  @combinations.generate(test_base.v1_only_combinations())
  def testDoNotRestore(self):
    """Without the hook, a later train() restarts the input pipeline."""

    def _input_fn():
      return dataset_ops.Dataset.range(10)

    est = estimator.Estimator(model_fn=self._model_fn)
    for expected in [(2, 1), (4, 3)]:
      est.train(
          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
      self.assertSequenceEqual(self._read_vars(est.model_dir), expected)
    # Hook not provided, input pipeline was not restored: the dataset starts
    # over, so the last feature seen is 1 again while global_step reaches 6.
    est.train(_input_fn, steps=2)
    self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))

  @combinations.generate(test_base.v1_only_combinations())
  def testRaiseErrorIfNoIterator(self):
    """Training with the hook fails when the input has no iterator to save."""

    def _input_fn():
      return constant_op.constant(1, dtype=dtypes.int64)

    est = estimator.Estimator(model_fn=self._model_fn)
    with self.assertRaises(ValueError):
      est.train(
          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])