def testElementSpecEagerMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  ds = data_service_ops.from_dataset_id("parallel_epochs",
                                        cluster.dispatcher_address(),
                                        dataset_id)
  self.assertDatasetProduces(ds, list(range(num_elements)))

def testSharedJobNameMultiIteration(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds1 = self.make_distributed_range_dataset(
      num_elements, cluster, job_name="job_name")
  ds2 = self.make_distributed_range_dataset(
      num_elements, cluster, job_name="job_name")
  # iteration 1
  self.assertDatasetProduces(ds1, list(range(num_elements)))
  self.assertDatasetProduces(ds2, [])
  # iteration 2
  self.assertDatasetProduces(ds2, list(range(num_elements)))
  self.assertDatasetProduces(ds1, [])

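# Datasets created with the same `job_name` attach to a single shared job,
# so its elements are handed out first-come-first-served across iterators:
# whichever dataset is read first drains the job, and the other observes an
# immediate end of sequence until the next iteration recreates the job.
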
def testZipDifferentProcessingModesDatasetsSharedJobName(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  ds1 = dataset_ops.Dataset.range(num_elements)
  ds1 = self.make_distributed_dataset(
      ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
  ds2 = dataset_ops.Dataset.range(num_elements)
  ds2 = self.make_distributed_dataset(
      ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
  ds = dataset_ops.Dataset.zip((ds1, ds2))
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "but there is already an existing job"):
    self.getDatasetOutput(ds)

def testConcatenate(self, num_workers):
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  a = dataset_ops.Dataset.range(100)
  b = dataset_ops.Dataset.range(100, 200)
  ds = a.concatenate(b)
  ds = self.make_distributed_dataset(
      ds, cluster, processing_mode="distributed_epoch")
  assert_items_equal = (num_workers > 1)
  self.assertDatasetProduces(
      ds, list(range(200)), assert_items_equal=assert_items_equal)

def testExplicitProtocolFromDatasetId(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  range_ds = dataset_ops.Dataset.range(10)
  dataset_id = data_service_ops.register_dataset(cluster.dispatcher.target,
                                                 range_ds)
  ds = data_service_ops.from_dataset_id(
      dataset_id=dataset_id,
      processing_mode="parallel_epochs",
      element_spec=range_ds.element_spec,
      service=cluster.dispatcher.target,
      data_transfer_protocol="grpc")
  self.assertDatasetProduces(ds, list(range(10)))

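# The two-step pattern above is how one client can share a dataset with
# others: `register_dataset` returns an id that can be passed out of band,
# and any client can then attach with `from_dataset_id`. A minimal sketch,
# with an illustrative placeholder address:
#
#   ds = dataset_ops.Dataset.range(10)
#   dataset_id = data_service_ops.register_dataset("grpc://host:5000", ds)
#   ds = data_service_ops.from_dataset_id(
#       "parallel_epochs", "grpc://host:5000", dataset_id,
#       element_spec=ds.element_spec)
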
def testRegisteringDatasetAsTfFunction(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  register_func = def_function.function(self.register_dataset)
  dataset_id = register_func(
      (constant_op.constant("grpc"),
       constant_op.constant(cluster.dispatcher_address())), ds)
  from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                            dataset_id, ds.element_spec)
  self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))

def testNestedZip(self):
  num_elements = 10
  cluster = data_service_test_base.TestCluster(num_workers=1)
  a = dataset_ops.Dataset.range(num_elements)
  ds = dataset_ops.Dataset.zip((a, a))
  ds = dataset_ops.Dataset.zip((a, a, ds, a))
  ds = self.make_distributed_dataset(
      ds, cluster, processing_mode="distributed_epoch")
  b = list(range(10))
  self.assertDatasetProduces(ds, list(zip(b, b, zip(b, b), b)))

def testSnapshot(self, already_written):
  num_workers = 3
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  ds = dataset_ops.Dataset.range(100)
  ds = ds.snapshot(self.get_temp_dir())
  if already_written:
    # Materialize the snapshot.
    self.getDatasetOutput(ds)

  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  error_regex = "Splitting is not implemented for snapshot datasets"
  with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
    self.getDatasetOutput(ds)

def testGcUnusedJob(self, job_name):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
  num_elements = 100
  ds = self.make_distributed_range_dataset(
      num_elements, cluster, job_name=job_name)
  it = iter(ds)
  self.assertEqual(next(it).numpy(), 0)
  self.assertEqual(cluster.workers[0].num_tasks(), 1)
  del it
  while cluster.workers[0].num_tasks() > 0:
    time.sleep(0.1)

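# As the parameter names suggest, the dispatcher scans for unused jobs every
# `job_gc_check_interval_ms` and collects a job once no client has used it
# for `job_gc_timeout_ms`; workers then drop the corresponding task. The
# polling loop above is needed because collection is asynchronous from the
# client's perspective.
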
def testImbalancedZipMultiWorker(self):
  smaller_num_elements = 200
  larger_num_elements = 1000
  cluster = data_service_test_base.TestCluster(num_workers=3)
  a = dataset_ops.Dataset.range(smaller_num_elements)
  b = dataset_ops.Dataset.range(larger_num_elements)
  ds = dataset_ops.Dataset.zip((a, b))
  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  # Cannot assert specific elements because the range datasets are split
  # nondeterministically and may not line up.
  self.assertLen(self.getDatasetOutput(ds), smaller_num_elements)

def testChooseFromDatasets(self, num_workers):
  cluster = data_service_test_base.TestCluster(num_workers=num_workers)
  words = [b"foo", b"bar", b"baz"]
  datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
  choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
  choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
  ds = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  expected = [words[i] for i in choice_array]
  assert_items_equal = (num_workers > 1)
  self.assertDatasetProduces(
      ds, expected, assert_items_equal=assert_items_equal)

def testImbalancedZip(self):
  smaller_num_elements = 200
  larger_num_elements = 1000
  cluster = data_service_test_base.TestCluster(num_workers=1)
  a = dataset_ops.Dataset.range(smaller_num_elements)
  b = dataset_ops.Dataset.range(larger_num_elements)
  ds = dataset_ops.Dataset.zip((a, b))
  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  # Zip stops once the smaller input dataset is exhausted.
  self.assertDatasetProduces(
      ds,
      list(zip(range(smaller_num_elements), range(smaller_num_elements))))

def testFlatMapWithRepeat(self):
  cluster = data_service_test_base.TestCluster(num_workers=3)
  ds = dataset_ops.Dataset.range(5)

  def flat_map_fn(_):
    return dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]).repeat(10)

  ds = ds.flat_map(flat_map_fn)
  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  self.assertDatasetProduces(
      ds, [b"a", b"b", b"c"] * 50, assert_items_equal=True)

def testConsumerRestart(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_consumers = 3
  ds = self.make_coordinated_read_dataset(cluster, num_consumers)
  get_next = self.getNext(ds, requires_initialization=False)
  _ = [self.evaluate(get_next()) for _ in range(20)]
  ds2 = self.make_coordinated_read_dataset(cluster, num_consumers)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "current round has already reached"):
    get_next_ds2 = self.getNext(ds2, requires_initialization=False)
    _ = [self.evaluate(get_next_ds2()) for _ in range(20)]
  cluster.stop_workers()

def testDispatcherRestartDuringReading(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  iterator = iter(ds)
  results = []
  for _ in range(num_elements // 2):
    results.append(next(iterator).numpy())
  cluster.restart_dispatcher()
  # Reading should resume from where it left off, with no dropped or
  # duplicated elements.
  for elem in iterator:
    results.append(elem.numpy())
  self.assertEqual(list(range(num_elements)), results)

def testDispatcherStop(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  ds = self.make_distributed_range_dataset(num_elements, cluster)
  iterator = iter(ds)
  results = []
  results.append(next(iterator).numpy())
  cluster.stop_dispatcher()
  # After the dispatcher dies, the worker should continue providing the rest
  # of the dataset's elements.
  for _ in range(num_elements - 1):
    results.append(next(iterator).numpy())
  self.assertEqual(results, list(range(num_elements)))

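# The dispatcher brokers job and task assignment, while element data flows
# directly from workers to clients; that is why the iterator above can keep
# reading after the dispatcher is stopped, and why the restart in the
# previous test is invisible to the reader.
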
def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  dataset = dataset_ops.Dataset.from_tensor_slices(
      list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset=dataset, compression=compression)
  dataset = data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher.target,
      dataset_id=dataset_id)
  self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

def testFromDatasetIdWrongElementSpec(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.dispatcher_address(), dataset_id, wrong_spec)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "Expected a tensor of type variant"):
    self.evaluate(self.getNext(from_dataset_id_ds)())

def testVariables(self, use_resource):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  if not use_resource:
    with variable_scope.variable_scope("foo", use_resource=False):
      v = variables.VariableV1(10, dtype=dtypes.int64)
  else:
    v = variables.Variable(10, dtype=dtypes.int64)

  ds = dataset_ops.Dataset.range(3)
  ds = ds.map(lambda x: x + v)
  ds = self.make_distributed_dataset(ds, cluster)
  self.evaluate(v.initializer)
  self.assertDatasetProduces(
      ds, list(range(10, 13)), requires_initialization=True)

def testElementSpecGraphMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  with self.assertRaisesRegex(
      ValueError, "In graph mode element_spec must be provided manually."):
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)

def testZipDifferentProcessingModesDatasets(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  ds1 = dataset_ops.Dataset.range(num_elements)
  ds1 = self.make_distributed_dataset(
      ds1, cluster, processing_mode="distributed_epoch")
  ds2 = dataset_ops.Dataset.range(num_elements)
  ds2 = self.make_distributed_dataset(
      ds2, cluster, processing_mode="parallel_epochs")
  ds = dataset_ops.Dataset.zip((ds1, ds2))
  self.assertDatasetProduces(
      ds,
      list(zip(range(num_elements), range(num_elements))),
      assert_items_equal=True)

def testFromDatasetIdMultipleComponents(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
  dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
  from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                            dataset_id, ds.element_spec)
  output = self.getDatasetOutput(from_dataset_id_ds)
  for i in range(num_elements):
    self.assertEqual(i, output[i]["a"][0])
    self.assertEqual(i, output[i]["a"][1])
    self.assertEqual(i, output[i]["b"])

def testCancellation(self):
  self.skipTest("b/162521601")
  sleep_microseconds = int(1e6) * 1000
  cluster = data_service_test_base.TestCluster(num_workers=1)
  # Create a dataset which produces the first element quickly, and the second
  # element slowly. Fetching the first element triggers prefetching of the
  # second element, which we should be able to cancel.
  slow = dataset_ops.Dataset.range(1)
  slow = slow.apply(testing.sleep(sleep_microseconds))
  ds = dataset_ops.Dataset.range(1).concatenate(slow)
  ds = self.make_distributed_dataset(ds, cluster)
  ds = ds.prefetch(1)
  get_next = self.getNext(ds)
  self.assertEqual(0, self.evaluate(get_next()))

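# `testing.sleep` injects a delay before each element it passes through, so
# the concatenated dataset yields element 0 immediately while element 1
# stalls for ~1000 seconds. The test can only finish promptly if tearing
# down the iterator cancels the in-flight prefetch of the slow element
# instead of blocking on it.
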
def testFromDatasetIdDoesntRequireElementSpec(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=NO_WORK_DIR,
      fault_tolerant_mode=False,
      data_transfer_protocol="grpc")
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  ds = data_service_ops.from_dataset_id("parallel_epochs",
                                        cluster.dispatcher_address(),
                                        dataset_id)
  self.assertDatasetProduces(ds, list(range(num_elements)))

def testGcAndRecreate(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=3, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
  num_elements = 1000
  # Repeatedly create and garbage-collect the same job.
  for _ in range(3):
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="test")
    it = iter(ds)
    for _ in range(50):
      next(it)
    del it
    # Wait for the task to be garbage-collected on all workers.
    while cluster.num_tasks_on_workers() > 0:
      time.sleep(0.1)

def testResourceOnWrongDevice(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  with ops.device(self._devices[0]):
    initializer = self.lookupTableInitializer("keyvaluetensor", [10, 11])
    table = lookup_ops.StaticHashTable(initializer, -1)
    self.evaluate(lookup_ops.tables_initializer())

  with ops.device(self._devices[1]):
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Serialization error while trying to register a dataset"):
      ds = self.make_distributed_dataset(ds, cluster)
      self.getDatasetOutput(ds, requires_initialization=True)

def testForeverRepeat(self):
  cluster = data_service_test_base.TestCluster(num_workers=2)
  num_elements = 20
  elements_to_read = 1000
  ds = dataset_ops.Dataset.range(num_elements).repeat()
  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  get_next = self.getNext(ds)
  results = {}
  for _ in range(elements_to_read):
    val = self.evaluate(get_next())
    if val not in results:
      results[val] = 0
    results[val] += 1
  for i in range(num_elements):
    self.assertGreater(results[i], elements_to_read / num_elements / 2)

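# The lower bound above is deliberately loose: with 1000 reads of 20
# repeating values, each value's fair share is 50 appearances, and the test
# only requires half that, tolerating dynamic sharding's nondeterministic
# distribution of splits across the two workers.
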
def testFiniteV1(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 100
  ds = dataset_ops.Dataset.range(num_elements)
  ds = self.make_distributed_dataset(
      ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
  with self.assertRaisesRegex(
      errors.FailedPreconditionError,
      "Encountered end of sequence on a "
      "round-robin read iterator"):
    self.getDatasetOutput(ds)

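# Coordinated (round-robin) reads require a dataset that repeats
# indefinitely: every consumer must receive an element in every round, so a
# finite dataset eventually leaves a round short, producing the error
# asserted above. Appending `.repeat()` to the dataset is the usual fix.
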
def testGroupByWindow(self):
  # Verify that split providers are not propagated into iterators created for
  # the reduce datasets created by the reduce_fn in group_by_window.
  cluster = data_service_test_base.TestCluster(num_workers=2)
  elements = [1, 5, 0]
  ds = dataset_ops.Dataset.from_tensor_slices(elements)

  def reduce_fn(_, window):
    return dataset_ops.Dataset.zip((window, dataset_ops.Dataset.range(100)))

  ds = ds.group_by_window(lambda x: 0, reduce_fn, window_size=3)
  ds = self._make_dynamic_sharding_dataset(ds, cluster)
  # This will fail if the tensor_slices split provider is propagated into the
  # `reduce_fn`, since the `zip` requires either 0 or 2 split providers.
  self.getDatasetOutput(ds)

def _testCompressionMismatch(self, dataset):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  with mock.patch.object(compat, "forward_compatible", return_value=False):
    dataset_id = data_service_ops._register_dataset(
        cluster.dispatcher.target, dataset=dataset, compression=None)
    # `compression` is "AUTO" by default.
    dataset = data_service_ops._from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher.target,
        dataset_id=dataset_id,
        element_spec=dataset.element_spec)
    with self.assertRaises(errors.InvalidArgumentError):
      self.getDatasetOutput(dataset)
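
# A note on the mock: forcing `compat.forward_compatible` to return False
# makes registration follow the older code path that sends no compression
# metadata, while the read side still defaults `compression` to "AUTO"; the
# mismatch surfaces as the InvalidArgumentError asserted above. (This
# reading is inferred from the test setup rather than from documented
# behavior.)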