def testFromDatasetIdMultipleComponents(self):
  cluster = self.create_cluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
  dataset_id = data_service_ops.register_dataset(cluster.target, ds)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
  output = self.getDatasetOutput(from_dataset_id_ds)
  for i in range(num_elements):
    self.assertEqual(i, output[i]["a"][0])
    self.assertEqual(i, output[i]["a"][1])
    self.assertEqual(i, output[i]["b"])
def testFromDatasetIdWrongElementSpec(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  ds = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), ds)
  # Pass a spec that does not match the registered dataset's element type.
  wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
  from_dataset_id_ds = data_service_ops.from_dataset_id(
      "parallel_epochs", cluster.dispatcher_address(), dataset_id, wrong_spec)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "Expected a tensor of type variant"):
    self.evaluate(self.getNext(from_dataset_id_ds)())
def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  dataset = dataset_ops.Dataset.from_tensor_slices(
      list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset=dataset, compression=compression)
  # `element_spec` is omitted; `from_dataset_id` fetches it from the
  # dispatcher.
  dataset = data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher.target,
      dataset_id=dataset_id)
  self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def testElementSpecGraphMode(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=data_service_test_base.NO_WORK_DIR,
      fault_tolerant_mode=False)
  num_elements = 10
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), dataset)
  # In graph mode the element spec cannot be inferred, so omitting it is an
  # error.
  with self.assertRaisesRegex(
      ValueError, "In graph mode `element_spec` must be provided manually."):
    _ = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id)
def testFromDatasetIdDoesntRequireElementSpec(self):
  cluster = data_service_test_base.TestCluster(
      num_workers=1,
      work_dir=data_service_test_base.NO_WORK_DIR,
      fault_tolerant_mode=False,
      data_transfer_protocol="grpc")
  num_elements = 10
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher_address(), dataset)
  dataset = data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher_address(),
      dataset_id=dataset_id)
  self.assertDatasetProduces(dataset, list(range(num_elements)))
def testReadDatasetOnDifferentDevices(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  num_elements = 10
  with ops.device(self._devices[0]):
    dataset = dataset_ops.Dataset.range(num_elements)
    element_spec = dataset.element_spec
    dataset_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), dataset)
    dataset = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id,
        element_spec=element_spec)
    self.assertDatasetProduces(dataset, list(range(num_elements)))
  with ops.device(self._devices[1]):
    # Read the same registered dataset from a second device, reusing the
    # element spec saved above.
    dataset = data_service_ops.from_dataset_id(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service=cluster.dispatcher_address(),
        dataset_id=dataset_id,
        element_spec=element_spec)
    self.assertDatasetProduces(dataset, list(range(num_elements)))
def testFromDatasetIdOmitsCompression(self, compression):
  cluster = data_service_test_base.TestCluster(
      num_workers=1, data_transfer_protocol="grpc")
  dataset = dataset_ops.Dataset.from_tensor_slices(
      list("abcdefghijklmnopqrstuvwxyz"))

  def to_upper(x):
    return script_ops.numpy_function(
        func=lambda x: x.decode("utf-8").upper(), inp=[x], Tout=dtypes.string)

  dataset = dataset.map(to_upper, num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset_id = data_service_ops.register_dataset(
      cluster.dispatcher.target, dataset=dataset, compression=compression)
  dataset = data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher.target,
      dataset_id=dataset_id,
      element_spec=dataset.element_spec)
  self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def get_dataset_id():
  return data_service_ops.register_dataset(
      cluster.dispatcher_address(), dataset)
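# Hypothetical usage sketch (not part of the original tests): a dataset id
# returned by `get_dataset_id` can be handed to `from_dataset_id` so that
# multiple readers share one registered dataset. Assumes `cluster` and
# `dataset` are defined in the enclosing test, as above.
def read_registered_dataset():
  return data_service_ops.from_dataset_id(
      processing_mode=data_service_ops.ShardingPolicy.OFF,
      service=cluster.dispatcher_address(),
      dataset_id=get_dataset_id(),
      element_spec=dataset.element_spec)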