def testCompression(self, element):
  element = element._obj

  compressed = compression_ops.compress(element)
  uncompressed = compression_ops.uncompress(
      compressed, structure.type_spec_from_value(element))
  self.assertValuesEqual(element, self.evaluate(uncompressed))
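
# A minimal standalone sketch of the compress/uncompress round trip that the
# test above exercises, assuming the usual TensorFlow-internal import paths
# for these ops; the example element is hypothetical.
import tensorflow as tf
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.util import structure

element = {"x": tf.constant([1, 2, 3]), "y": tf.constant("hello")}
# compress() packs an arbitrary nested structure into a single variant tensor.
compressed = compression_ops.compress(element)
# uncompress() needs the element's type spec to rebuild the structure.
uncompressed = compression_ops.uncompress(
    compressed, output_spec=structure.type_spec_from_value(element))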

def _register_dataset(service, dataset, compression):
  """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))

  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)

  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  encoded_spec = ""
  if context.executing_eagerly():
    coder = nested_structure_coder.StructureCoder()
    encoded_spec = coder.encode_structure(
        dataset.element_spec).SerializeToString()

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(lambda *x: compression_ops.compress(x),
                          num_parallel_calls=dataset_ops.AUTOTUNE)
  else:
    # TODO(damien-aymon): Make this cleaner.
    # EASL - we set this because we look out for the first "map" operation to
    # insert our caching ops.
    dataset = dataset.apply(service_cache_mark("noop"))
  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value,
      element_spec=encoded_spec)
  return dataset_id
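
# Hedged usage sketch for the helper above: the service argument may be either
# a "[protocol://]address" string or a (protocol, address) tuple. The
# dispatcher address used here is a placeholder for illustration only.
dataset = dataset_ops.Dataset.range(100)
# String form: the protocol prefix is optional.
dataset_id = _register_dataset("grpc://localhost:5000", dataset,
                               compression=COMPRESSION_AUTO)
# Tuple form: protocol and address are given separately, bypassing parsing.
dataset_id = _register_dataset(("grpc", "localhost:5000"), dataset,
                               compression=COMPRESSION_AUTO)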

def testCompressionOutputDTypeMismatch(self):
  element = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  compressed = compression_ops.compress(element)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              "but got a tensor of type string"):
    uncompressed = compression_ops.uncompress(
        compressed, structure.type_spec_from_value(0))
    self.evaluate(uncompressed)

def testDatasetCompression(self, element):
  element = element._obj

  dataset = dataset_ops.Dataset.from_tensors(element)
  element_spec = dataset.element_spec

  dataset = dataset.map(lambda *x: compression_ops.compress(x))
  dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec))
  self.assertDatasetProduces(dataset, [element])

def testCompressionInputShapeMismatch(self):
  element = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  compressed = compression_ops.compress(element)
  compressed = [compressed, compressed]
  error = (
      errors.InvalidArgumentError
      if context.executing_eagerly() else ValueError)
  with self.assertRaises(error):
    uncompressed = compression_ops.uncompress(
        compressed, structure.type_spec_from_value(0))
    self.evaluate(uncompressed)

def _apply_fn(dataset):  # pylint: disable=missing-docstring
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  uncompressed_spec = dataset.element_spec
  # Compress the dataset elements to reduce the amount of data that needs to
  # be sent over the network.
  # TODO(b/157105111): Make this an autotuned parallel map when we have a way
  # to limit memory usage.
  dataset = dataset.map(lambda *x: compression_ops.compress(x))
  # Prefetch one compressed element to reduce latency when requesting data
  # from tf.data workers.
  # TODO(b/157105111): Set this to autotune when we have a way to limit
  # memory usage.
  dataset = dataset.prefetch(1)
  # Apply options so that the dataset executed in the tf.data service will
  # be optimized and support autotuning.
  dataset = dataset._apply_options()  # pylint: disable=protected-access
  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value)
  dataset = _DataServiceDataset(
      input_dataset=dataset,
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      protocol=protocol,
      job_name=job_name,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
  # TODO(b/157105111): Make this an autotuned parallel map when we have a way
  # to limit memory usage.
  # The value 16 is chosen based on experience with pipelines that require
  # more than 8 parallel calls to prevent this stage from being a bottleneck.
  dataset = dataset.map(
      lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec),
      num_parallel_calls=16)

  # Disable autosharding for shared jobs.
  if job_name:
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset
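
# _apply_fn above is the closure that the distribute transformation applies to
# the input dataset. A hedged usage sketch of the corresponding public API is
# shown below; the dispatcher address and job name are placeholders, not
# values taken from this module.
import tensorflow as tf

dataset = tf.data.Dataset.range(1000)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs",
        service="grpc://localhost:5000",
        job_name="shared_job"))  # a shared job also disables autosharding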

def _register_dataset(service, dataset, compression):
  """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format `[<protocol>://]<address>`, where
      `<address>` identifies the dispatcher address and `<protocol>` can
      optionally be used to override the default protocol to use.
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(
        "Invalid compression argument: {}. Must be one of {}".format(
            compression, valid_compressions))
  protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(lambda *x: compression_ops.compress(x),
                          num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
  # Apply options so that the dataset executed in the tf.data service will
  # be optimized and support autotuning.
  # TODO(b/183497230): Move options application after deserialization.
  dataset = dataset._apply_options()  # pylint: disable=protected-access
  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value)
  return dataset_id

def register_dataset(service, dataset):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    dataset: A `tf.data.Dataset` to register with the tf.data service.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
  protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  # Compress the dataset elements to reduce the amount of data that needs to
  # be sent over the network.
  dataset = dataset.map(lambda *x: compression_ops.compress(x),
                        num_parallel_calls=dataset_ops.AUTOTUNE)
  # Prefetch one compressed element to reduce latency when requesting data
  # from tf.data workers.
  # TODO(b/157105111): Set this to autotune when we have a way to limit
  # memory usage.
  dataset = dataset.prefetch(1)
  # Apply options so that the dataset executed in the tf.data service will
  # be optimized and support autotuning.
  dataset = dataset._apply_options()  # pylint: disable=protected-access
  dataset_id = gen_experimental_dataset_ops.register_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value)
  return dataset_id