Example No. 1
  def testCompression(self, element):
    element = element._obj

    compressed = compression_ops.compress(element)
    uncompressed = compression_ops.uncompress(
        compressed, structure.type_spec_from_value(element))
    self.assertValuesEqual(element, self.evaluate(uncompressed))
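The test above round-trips a single element through the internal compression ops. A minimal standalone sketch of the same idea, assuming the internal modules are importable at the paths used in recent TensorFlow source trees (these are not public APIs):

import tensorflow as tf

from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.util import structure

element = {"a": tf.constant([1, 2, 3]), "b": tf.constant("hello")}
# compress() flattens the structure and packs it into a single variant tensor.
compressed = compression_ops.compress(element)
# uncompress() needs the type spec of the original structure to unpack it.
uncompressed = compression_ops.uncompress(
    compressed, structure.type_spec_from_value(element))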
Example No. 2
def _register_dataset(service, dataset, compression):
    """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
    valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
    if compression not in valid_compressions:
        raise ValueError(
            "Invalid compression argument: {}. Must be one of {}".format(
                compression, valid_compressions))
    if isinstance(service, tuple):
        protocol, address = service
    else:
        protocol, address = _parse_service(service)
    external_state_policy = dataset.options(
    ).experimental_external_state_policy
    if external_state_policy is None:
        external_state_policy = ExternalStatePolicy.WARN

    encoded_spec = ""
    if context.executing_eagerly():
        coder = nested_structure_coder.StructureCoder()
        encoded_spec = coder.encode_structure(
            dataset.element_spec).SerializeToString()

    if compression == COMPRESSION_AUTO:
        dataset = dataset.map(lambda *x: compression_ops.compress(x),
                              num_parallel_calls=dataset_ops.AUTOTUNE)
    else:
        # TODO(damien-aymon): Make this cleaner.
        # EASL: the caching logic looks for the first "map" operation when
        # deciding where to insert its caching ops, so mark the pipeline with a
        # no-op here when no compress map is added.
        dataset = dataset.apply(service_cache_mark("noop"))

    dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
    dataset = dataset._apply_debug_options()  # pylint: disable=protected-access

    dataset_id = gen_experimental_dataset_ops.register_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        address=address,
        protocol=protocol,
        external_state_policy=external_state_policy.value,
        element_spec=encoded_spec)

    return dataset_id
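A hedged usage sketch for the helper above, assuming it runs in the same module (so service_cache_mark and the compression constants are in scope), that COMPRESSION_AUTO == "AUTO" and COMPRESSION_NONE is None as the docstring describes, and that an in-process dispatcher is available:

import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer()
ds = tf.data.Dataset.range(10)
# String form of `service`: "[<protocol>://]<address>".
id_from_str = _register_dataset(dispatcher.target, ds, compression="AUTO")
# Tuple form of `service`: (protocol, address). With compression=None the
# compress map is skipped and the "noop" cache mark is applied instead.
id_from_tuple = _register_dataset(
    ("grpc", dispatcher.target.split("://")[1]), ds, compression=None)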
Example No. 3
  def testCompressionOutputDTypeMismatch(self):
    element = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    compressed = compression_ops.compress(element)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but got a tensor of type string"):
      uncompressed = compression_ops.uncompress(
          compressed, structure.type_spec_from_value(0))
      self.evaluate(uncompressed)
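For contrast with the mismatch above, a brief sketch of the matching-spec case, assuming the same internal imports as the test (not public APIs):

from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.util import structure

element = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
compressed = compression_ops.compress(element)
# Uncompressing with the spec of the original string element round-trips it...
ok = compression_ops.uncompress(
    compressed, structure.type_spec_from_value(element))
# ...whereas the int spec derived from the literal 0, as in the test above,
# makes the kernel fail with "but got a tensor of type string".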
Example No. 4
  def testDatasetCompression(self, element):
    element = element._obj

    dataset = dataset_ops.Dataset.from_tensors(element)
    element_spec = dataset.element_spec

    dataset = dataset.map(lambda *x: compression_ops.compress(x))
    dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec))
    self.assertDatasetProduces(dataset, [element])
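Note that the element spec is captured before the compress map: afterwards the dataset's element_spec describes the compressed variant rather than the original structure. A minimal sketch of the same pipeline on a concrete dataset, assuming the internal compression_ops module is importable (not a public API):

import tensorflow as tf

from tensorflow.python.data.experimental.ops import compression_ops

ds = tf.data.Dataset.range(5)
spec = ds.element_spec  # capture the uncompressed spec first
ds = ds.map(lambda *x: compression_ops.compress(x))
ds = ds.map(lambda x: compression_ops.uncompress(x, spec))
print(list(ds.as_numpy_iterator()))  # expected: [0, 1, 2, 3, 4]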
Example No. 5
  def testCompressionInputShapeMismatch(self):
    element = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    compressed = compression_ops.compress(element)
    compressed = [compressed, compressed]
    error = (
        errors.InvalidArgumentError
        if context.executing_eagerly() else ValueError)
    with self.assertRaises(error):
      uncompressed = compression_ops.uncompress(
          compressed, structure.type_spec_from_value(0))
      self.evaluate(uncompressed)
Example No. 6
    def _apply_fn(dataset):  # pylint: disable=missing-docstring
        external_state_policy = dataset.options(
        ).experimental_external_state_policy
        if external_state_policy is None:
            external_state_policy = ExternalStatePolicy.WARN

        uncompressed_spec = dataset.element_spec
        # Compress the dataset elements to reduce the amount of data that needs to
        # be sent over the network.
        # TODO(b/157105111): Make this an autotuned parallel map when we have a way
        # to limit memory usage.
        dataset = dataset.map(lambda *x: compression_ops.compress(x))
        # Prefetch one compressed element to reduce latency when requesting data
        # from tf.data workers.
        # TODO(b/157105111): Set this to autotune when we have a way to limit
        # memory usage
        dataset = dataset.prefetch(1)
        # Apply options so that the dataset executed in the tf.data service will
        # be optimized and support autotuning.
        dataset = dataset._apply_options()  # pylint: disable=protected-access
        dataset_id = gen_experimental_dataset_ops.register_dataset(
            dataset._variant_tensor,  # pylint: disable=protected-access
            address=address,
            protocol=protocol,
            external_state_policy=external_state_policy.value)
        dataset = _DataServiceDataset(
            input_dataset=dataset,
            dataset_id=dataset_id,
            processing_mode=processing_mode,
            address=address,
            protocol=protocol,
            job_name=job_name,
            max_outstanding_requests=max_outstanding_requests,
            task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
        # TODO(b/157105111): Make this an autotuned parallel map when we have a way
        # to limit memory usage.
        # The value 16 is chosen based on experience with pipelines that require
        # more than 8 parallel calls to prevent this stage from being a bottleneck.
        dataset = dataset.map(lambda x: compression_ops.uncompress(
            x, output_spec=uncompressed_spec),
                              num_parallel_calls=16)

        # Disable autosharding for shared jobs.
        if job_name:
            options = dataset_ops.Options()
            options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
            dataset = dataset.with_options(options)
        return dataset
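The snippet above is the body of a transformation factory: the enclosing function (evidently the distribute-style transformation in data_service_ops.py, not shown here) returns _apply_fn, and dataset.apply() calls it with the upstream dataset. A toy factory in the same shape, using only public tf.data APIs:

import tensorflow as tf

def _with_prefetch(buffer_size):
    """Illustrative factory shaped like the snippet above (not the real one)."""
    def _apply_fn(dataset):
        # dataset.apply() passes the upstream dataset in and uses the return
        # value as the transformed dataset.
        return dataset.prefetch(buffer_size)
    return _apply_fn

ds = tf.data.Dataset.range(10).apply(_with_prefetch(1))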
Example No. 7
def _register_dataset(service, dataset, compression):
    """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format `[<protocol>://]<address>`, where
      `<address>` identifies the dispatcher address and `<protocol>` can
      optionally be used to override the default protocol to use.
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring them
      over the network. "AUTO" leaves the decision of how to compress up to the
      tf.data service runtime. `None` indicates not to compress.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
    valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
    if compression not in valid_compressions:
        raise ValueError(
            "Invalid compression argument: {}. Must be one of {}".format(
                compression, valid_compressions))
    protocol, address = _parse_service(service)
    external_state_policy = dataset.options(
    ).experimental_external_state_policy
    if external_state_policy is None:
        external_state_policy = ExternalStatePolicy.WARN

    if compression == COMPRESSION_AUTO:
        dataset = dataset.map(lambda *x: compression_ops.compress(x),
                              num_parallel_calls=dataset_ops.AUTOTUNE)
    dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
    # Apply options so that the dataset executed in the tf.data service will
    # be optimized and support autotuning.
    # TODO(b/183497230): Move options application after deserialization.
    dataset = dataset._apply_options()  # pylint: disable=protected-access

    dataset_id = gen_experimental_dataset_ops.register_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        address=address,
        protocol=protocol,
        external_state_policy=external_state_policy.value)

    return dataset_id
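The compression argument is validated before anything is sent to the dispatcher, so an unsupported value fails fast. A small sketch, assuming COMPRESSION_AUTO == "AUTO" and COMPRESSION_NONE is None and that this runs in the same module as the helper above ("GZIP" and the address are purely illustrative):

import tensorflow as tf

try:
    _register_dataset("grpc://localhost:5000", tf.data.Dataset.range(3),
                      compression="GZIP")
except ValueError as e:
    # Expected, assuming the constants above:
    # Invalid compression argument: GZIP. Must be one of ['AUTO', None]
    print(e)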
Example No. 8
def register_dataset(service, dataset):
    """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format "protocol://address", e.g.
      "grpc://localhost:5000".
    dataset: A `tf.data.Dataset` to register with the tf.data service.

  Returns:
    A scalar int64 tensor of the registered dataset's id.
  """
    protocol, address = _parse_service(service)
    external_state_policy = dataset.options(
    ).experimental_external_state_policy
    if external_state_policy is None:
        external_state_policy = ExternalStatePolicy.WARN

    # Compress the dataset elements to reduce the amount of data that needs to
    # be sent over the network.
    dataset = dataset.map(lambda *x: compression_ops.compress(x),
                          num_parallel_calls=dataset_ops.AUTOTUNE)
    # Prefetch one compressed element to reduce latency when requesting data
    # from tf.data workers.
    # TODO(b/157105111): Set this to autotune when we have a way to limit
    # memory usage
    dataset = dataset.prefetch(1)
    # Apply options so that the dataset executed in the tf.data service will
    # be optimized and support autotuning.
    dataset = dataset._apply_options()  # pylint: disable=protected-access

    dataset_id = gen_experimental_dataset_ops.register_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        address=address,
        protocol=protocol,
        external_state_policy=external_state_policy.value)

    return dataset_id
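One behavior the doctest above does not show: per the docstring, registering a dataset that is already registered returns the existing id. A hedged sketch with an in-process dispatcher (the ids are expected to match only when the registered dataset graphs are identical):

import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer()
dataset = tf.data.Dataset.range(10)
id_a = tf.data.experimental.service.register_dataset(dispatcher.target, dataset)
id_b = tf.data.experimental.service.register_dataset(dispatcher.target, dataset)
# id_a and id_b should refer to the same registered dataset.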