def testFromDatasetId(self): """Tests cross-trainer cache with `register_dataset`/`from_dataset_id`.""" cluster = self._create_cluster(num_workers=1) dataset = dataset_ops.Dataset.range(10000000).repeat() dataset_id1 = data_service_ops.register_dataset( cluster.dispatcher.target, dataset, dataset_id="dataset_id") dataset1 = data_service_ops.from_dataset_id( processing_mode=data_service_ops.ShardingPolicy.OFF, service=cluster.dispatcher.target, dataset_id=dataset_id1, element_spec=dataset.element_spec, job_name="job", cross_trainer_cache=data_service_ops.CrossTrainerCache( trainer_id="Trainer 1")) self.assertDatasetProduces(dataset1.take(10), list(range(10))) dataset_id2 = data_service_ops.register_dataset( cluster.dispatcher.target, dataset, dataset_id="dataset_id") dataset2 = data_service_ops.from_dataset_id( processing_mode=data_service_ops.ShardingPolicy.OFF, service=cluster.dispatcher.target, dataset_id=dataset_id2, element_spec=dataset.element_spec, job_name="job", cross_trainer_cache=data_service_ops.CrossTrainerCache( trainer_id="Trainer 2")) self.assertDatasetProduces(dataset2.take(10), list(range(10)))
def testEmptyJobNameFromDatasetId(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset_ops.Dataset.range(10))
  with self.assertRaisesRegex(ValueError, "job_name must not be empty"):
    data_service_ops.from_dataset_id(
        dataset_id=dataset_id,
        processing_mode="parallel_epochs",
        service=cluster.dispatcher.target,
        job_name="")
def testNonStringJobNameFromDatasetId(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset_ops.Dataset.range(10))
  with self.assertRaisesRegex(ValueError, "`job_name` must be a string"):
    data_service_ops.from_dataset_id(
        dataset_id=dataset_id,
        processing_mode="parallel_epochs",
        service=cluster.dispatcher.target,
        job_name=constant_op.constant("foo"))
def testElementSpecMixedMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=data_service_test_base.NO_WORK_DIR,
      fault_tolerant_mode=False)
  num_elements = 10
  dataset = dataset_ops.Dataset.range(num_elements)

  @def_function.function
  def get_dataset_id():
    return data_service_ops.register_dataset(cluster.dispatcher_address(),
                                             dataset)

  dataset_id = get_dataset_id()
  dataset_id_val = tensor_util.constant_value(dataset_id)

  with self.assertRaisesRegex(
      ValueError,
      f"Failed to fetch element spec for dataset id {dataset_id_val} from "
      "tf.data service. If the dataset was registered in graph mode or "
      "inside a tf.function, the `element_spec` must be specified as an "
      "argument to `from_dataset_id`."):
    dataset = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id)
def testFromDatasetIdOmitsCompression(self, compression):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  dataset = dataset_ops.Dataset.from_tensor_slices(
      list("abcdefghijklmnopqrstuvwxyz"))

  def to_upper(x):
    return script_ops.numpy_function(
        func=lambda x: x.decode("utf-8").upper(), inp=[x], Tout=dtypes.string)

  dataset = dataset.map(to_upper, num_parallel_calls=dataset_ops.AUTOTUNE)
  with mock.patch.object(compat, "forward_compatible", return_value=True):
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher.target, dataset=dataset, compression=compression)
    dataset = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher.target,
        dataset_id=dataset_id,
        element_spec=dataset.element_spec)
    self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def testFromDatasetIdSharedJobs(self):
  cluster = data_service_test_base.TestCluster(num_workers=2)
  datasets = [
      dataset_ops.Dataset.range(20, output_type=dtypes.int32),
      dataset_ops.Dataset.from_tensor_slices(list(range(20, 40)))
  ]
  dataset_ids = [
      data_service_ops.register_dataset(cluster.dispatcher_address(), ds)
      for ds in datasets
  ]

  # Read from both jobs in parallel, with 2 consumers for each job.
  data_service_datasets = []
  for _ in range(2):
    for dataset, dataset_id in zip(datasets, dataset_ids):
      ds = data_service_ops.from_dataset_id(
          "distributed_epoch",
          cluster.dispatcher_address(),
          dataset_id,
          dataset.element_spec,
          job_name="shared_job")
      data_service_datasets.append(ds)
  ds = dataset_ops.Dataset.from_tensor_slices(data_service_datasets)
  ds = ds.interleave(lambda x: x, cycle_length=len(data_service_datasets))

  self.assertDatasetProduces(ds, list(range(40)), assert_items_equal=True)
def testFromDatasetId(self):
  num_elements = 10
  service = self.create_cluster(1)
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(service, ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", service, dataset_id, ds.element_spec)
  self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
def testFromDatasetIdNotRegistered(self):
  dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
  dataset_id = 0
  element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", dispatcher.target, dataset_id, element_spec)
  with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
    self.evaluate(self.getNext(from_dataset_id_ds)())
def testFromDatasetId(self):
  dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", dispatcher.target, dataset_id, ds.element_spec)
  self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
def testFromDatasetIdNotRegistered(self):
  cluster = self.create_cluster(num_workers=1)
  dataset_id = 0
  element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.target, dataset_id, element_spec)
  with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
    self.evaluate(self.getNext(from_dataset_id_ds)())
def testFromDatasetId(self):
  cluster = self.create_cluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(cluster.target, ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
  self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
def testFromDatasetId(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.dispatcher_address(), dataset_id,
      ds.element_spec)
  self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
def testElementSpecEagerMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  ds = data_service_ops.from_dataset_id("parallel_epochs",
                                        cluster.dispatcher_address(),
                                        dataset_id)
  self.assertDatasetProduces(ds, list(range(num_elements)))
def testFromDatasetIdWrongElementSpec(self):
  cluster = self.create_cluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(cluster.target, ds)
  wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.target, dataset_id, wrong_spec)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "Expected a tensor of type variant"):
    self.evaluate(self.getNext(from_dataset_id_ds)())
def testElementSpecGraphMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  with self.assertRaisesRegex(
      ValueError, "In graph mode element_spec must be provided manually."):
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)
def testFromDatasetIdCardinality(self, dataset_fn, sharding_policy,
                                 expected_result):
  cluster = data_service_test_base.TestCluster(num_workers=2)
  dataset = dataset_fn()
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset=dataset)
  dataset = data_service_ops.from_dataset_id(
      processing_mode=sharding_policy,
      service=cluster.dispatcher.target,
      dataset_id=dataset_id,
      element_spec=dataset.element_spec)
  self.assertEqual(self.evaluate(dataset.cardinality()), expected_result)
def testFromDatasetIdWrongElementSpec(self):
  dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
  wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", dispatcher.target, dataset_id, wrong_spec)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "Expected a tensor of type variant"):
    self.evaluate(self.getNext(from_dataset_id_ds)())
def testExplicitProtocolFromDatasetId(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  range_ds = dataset_ops.Dataset.range(10)
  dataset_id = data_service_ops.register_dataset(cluster.dispatcher.target,
                                                 range_ds)
  ds = data_service_ops.from_dataset_id(
      dataset_id=dataset_id,
      processing_mode="parallel_epochs",
      element_spec=range_ds.element_spec,
      service=cluster.dispatcher.target,
      data_transfer_protocol="grpc")
  self.assertDatasetProduces(ds, list(range(10)))
def testRegisteringDatasetAsTfFunction(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  register_func = def_function.function(data_service_ops.register_dataset)
  # The service is passed as a (protocol, address) tuple of tensors so that
  # registration can be traced inside a tf.function.
  dataset_id = register_func(
      (constant_op.constant("grpc"),
       constant_op.constant(cluster.dispatcher_address())), ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.dispatcher_address(), dataset_id,
      ds.element_spec)
  self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
def testReadDatasetOnDifferentDevices(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  with ops.device(self._devices[0]):
    dataset = dataset_ops.Dataset.range(num_elements)
    element_spec = dataset.element_spec
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), dataset)
    dataset = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id,
        element_spec=element_spec)
    self.assertDatasetProduces(dataset, list(range(num_elements)))
  with ops.device(self._devices[1]):
    dataset = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id,
        element_spec=element_spec)
    self.assertDatasetProduces(dataset, list(range(num_elements)))
def from_dataset_id(self,
                    processing_mode,
                    cluster,
                    dataset_id,
                    element_spec,
                    job_name=None):
  """Reads from the service via `from_dataset_id`, using the test's transfer protocol."""
  return data_service_ops.from_dataset_id(
      processing_mode,
      cluster.dispatcher_address(),
      dataset_id,
      element_spec,
      data_transfer_protocol=TRANSFER_PROTOCOL.value,
      job_name=job_name)
def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  dataset = dataset_ops.Dataset.from_tensor_slices(
      list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset=dataset, compression=compression)
  dataset = data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher.target,
      dataset_id=dataset_id)
  self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def testFromDatasetIdMultipleComponents(self):
  dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
  dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", dispatcher.target, dataset_id, ds.element_spec)
  output = self.getDatasetOutput(from_dataset_id_ds)
  for i in range(num_elements):
    self.assertEqual(i, output[i]["a"][0])
    self.assertEqual(i, output[i]["a"][1])
    self.assertEqual(i, output[i]["b"])
def testFromDatasetIdMultipleComponents(self):
  cluster = self.create_cluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
  dataset_id = data_service_ops.register_dataset(cluster.target, ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
  output = self.getDatasetOutput(from_dataset_id_ds)
  for i in range(num_elements):
    self.assertEqual(i, output[i]["a"][0])
    self.assertEqual(i, output[i]["a"][1])
    self.assertEqual(i, output[i]["b"])
def testFromDatasetIdDoesntRequireElementSpec(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=NO_WORK_DIR,
      fault_tolerant_mode=False,
      data_transfer_protocol="grpc")
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  ds = data_service_ops.from_dataset_id("parallel_epochs",
                                        cluster.dispatcher_address(),
                                        dataset_id)
  self.assertDatasetProduces(ds, list(range(num_elements)))
def testFromDatasetIdDoesntRequireElementSpec(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=data_service_test_base.NO_WORK_DIR,
      fault_tolerant_mode=False,
      data_transfer_protocol="grpc")
  num_elements = 10
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), dataset)
  dataset = data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher_address(),
      dataset_id=dataset_id)
  self.assertDatasetProduces(dataset, list(range(num_elements)))
def testElementSpecGraphMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=data_service_test_base.NO_WORK_DIR,
      fault_tolerant_mode=False)
  num_elements = 10
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), dataset)
  with self.assertRaisesRegex(
      ValueError, "In graph mode `element_spec` must be provided manually."):
    _ = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id)
def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  dataset = dataset_ops.Dataset.from_tensor_slices(
      list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
  with mock.patch.object(compat, "forward_compatible", return_value=True):
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher.target, dataset=dataset, compression=compression)
    dataset = data_service_ops.from_dataset_id(
        processing_mode=ShardingPolicy.OFF,
        service=cluster.dispatcher.target,
        dataset_id=dataset_id)
    self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def testElementSpecMixedMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)

  @def_function.function
  def get_dataset_id():
    return data_service_ops.register_dataset(cluster.dispatcher_address(), ds)

  dataset_id = get_dataset_id()
  dataset_id_val = tensor_util.constant_value(dataset_id)
  with self.assertRaisesRegex(
      ValueError,
      "Failed to fetch element spec for dataset id " + str(dataset_id_val) +
      " from tf.data service. If the "
      "dataset was registered in graph mode or inside a "
      "tf.function, the `element_spec` must be specified as "
      "an argument to `from_dataset_id`."):
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)