Example #1
  def __init__(self,
               input_dataset,
               dataset_id,
               processing_mode,
               address,
               protocol,
               job_name=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
    """Constructs a _DataServiceDatasetV2.

    Args:
      input_dataset: The input dataset, which should be registered with the
        tf.data service under `dataset_id`.
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A string specifying the policy for how data should be
        processed by tf.data workers. Currently, the only supported value is
        "parallel_epochs".
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it
        possible for multiple datasets to share the same job. The default
        behavior is that the dataset creates anonymous, exclusively owned jobs.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
    """

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._element_spec = input_dataset.element_spec

    variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=(
            gen_experimental_dataset_ops.dummy_iteration_counter()),
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)
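
For orientation, here is a minimal sketch of how a pipeline typically reaches this constructor through the public API; the dispatcher address is hypothetical and assumes a tf.data service is already running:

import tensorflow as tf

# Hypothetical endpoint; a dispatcher must already be serving here.
service = "grpc://localhost:5000"

dataset = tf.data.Dataset.range(10)
# `distribute` registers the dataset with the service and builds a
# _DataServiceDataset under the hood to read processed elements back.
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=service))

for element in dataset:
  print(element)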
Example #2
    def __init__(self,
                 input_dataset,
                 dataset_id,
                 address,
                 protocol,
                 max_outstanding_requests=None,
                 task_refresh_interval_hint_ms=None):
        """Constructs a _DataServiceDatasetV2.

        Args:
          input_dataset: The input dataset, which should be registered with
            the tf.data service under `dataset_id`.
          dataset_id: The dataset id for the dataset to read from.
          address: The tf.data service address, e.g. "localhost:5000".
          protocol: The protocol to use for communicating with the tf.data
            service, e.g. "grpc".
          max_outstanding_requests: (Optional.) A limit on how many elements
            may be requested at the same time. You can use this option to
            control the amount of memory used, since `distribute` won't use
            more than `element_size` * `max_outstanding_requests` of memory.
          task_refresh_interval_hint_ms: (Optional.) A hint for how often to
            query the dispatcher for task changes.
        """

        if max_outstanding_requests is None:
            max_outstanding_requests = dataset_ops.AUTOTUNE
        if task_refresh_interval_hint_ms is None:
            task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

        self._element_spec = input_dataset.element_spec
        self._dataset_id = dataset_id
        self._address = address
        self._protocol = protocol
        self._max_outstanding_requests = max_outstanding_requests
        self._task_refresh_interval_hint_ms = task_refresh_interval_hint_ms

        variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
            # The registered dataset id tells the service which dataset to
            # read from.
            dataset_id=dataset_id,
            address=address,
            protocol=protocol,
            max_outstanding_requests=max_outstanding_requests,
            task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
            **self._flat_structure)
        super(_DataServiceDatasetV2, self).__init__(variant_tensor)
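
To make the memory bound from the `max_outstanding_requests` docstring concrete, a back-of-the-envelope calculation with hypothetical numbers:

# Suppose each element is a batch of 32 float32 images of shape 224x224x3:
# 32 * 224 * 224 * 3 * 4 bytes per element.
element_size_bytes = 32 * 224 * 224 * 3 * 4  # ~19.3 MB

# With at most 8 outstanding requests, the client buffers no more than
# element_size * max_outstanding_requests bytes.
max_outstanding_requests = 8
bound_mb = element_size_bytes * max_outstanding_requests / 1e6
print(f"{bound_mb:.1f} MB")  # ~154.1 MB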
Example #3
    def __init__(self,
                 dataset_id,
                 processing_mode,
                 address,
                 protocol,
                 job_name=None,
                 max_outstanding_requests=None,
                 task_refresh_interval_hint_ms=None):
        """Constructs a _DataServiceDatasetV2.

        Args:
          dataset_id: The dataset id for the dataset to read from.
          processing_mode: A string specifying the policy for how data should
            be processed by tf.data workers. Currently, the only supported
            value is "parallel_epochs".
          address: The tf.data service address, e.g. "localhost:5000".
          protocol: The protocol to use for communicating with the tf.data
            service, e.g. "grpc".
          job_name: (Optional.) The name of the job. This argument makes it
            possible for multiple datasets to share the same job. The default
            behavior is that the dataset creates anonymous, exclusively owned
            jobs.
          max_outstanding_requests: (Optional.) A limit on how many elements
            may be requested at the same time. You can use this option to
            control the amount of memory used, since `distribute` won't use
            more than `element_size` * `max_outstanding_requests` of memory.
          task_refresh_interval_hint_ms: (Optional.) A hint for how often to
            query the dispatcher for task changes.
        """

        if job_name is None:
            job_name = ""
        if max_outstanding_requests is None:
            max_outstanding_requests = dataset_ops.AUTOTUNE
        if task_refresh_interval_hint_ms is None:
            task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

        self._dataset_id = ops.convert_to_tensor(dataset_id,
                                                 dtype=dtypes.int64,
                                                 name="dataset_id")
        self._processing_mode = ops.convert_to_tensor(processing_mode,
                                                      dtype=dtypes.string,
                                                      name="processing_mode")
        self._address = ops.convert_to_tensor(address,
                                              dtype=dtypes.string,
                                              name="address")
        self._protocol = ops.convert_to_tensor(protocol,
                                               dtype=dtypes.string,
                                               name="protocol")
        self._job_name = ops.convert_to_tensor(job_name,
                                               dtype=dtypes.string,
                                               name="job_name")
        self._max_outstanding_requests = ops.convert_to_tensor(
            max_outstanding_requests,
            dtype=dtypes.int64,
            name="max_outstanding_requests")
        # Datasets executed by the tf.data service produce compressed elements
        # represented by scalar DT_VARIANTs.
        self._element_spec = tensor_spec.TensorSpec(shape=(),
                                                    dtype=dtypes.variant)

        variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
            dataset_id=self._dataset_id,
            processing_mode=self._processing_mode,
            address=self._address,
            protocol=self._protocol,
            job_name=self._job_name,
            max_outstanding_requests=self._max_outstanding_requests,
            task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
            iteration_counter=(
                gen_experimental_dataset_ops.dummy_iteration_counter()),
            **self._flat_structure)
        super(_DataServiceDatasetV2, self).__init__(variant_tensor)
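
Since this version produces compressed elements as scalar variants, a consumer of this dataset needs to decompress them before use. A minimal sketch of that step, assuming the `uncompress` helper from TensorFlow's compression_ops and that `element_spec` holds the spec of the originally registered dataset:

from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.ops import dataset_ops

# `dataset` yields scalar DT_VARIANT tensors holding compressed elements;
# mapping `uncompress` restores the original element structure.
dataset = dataset.map(
    lambda x: compression_ops.uncompress(x, output_spec=element_spec),
    num_parallel_calls=dataset_ops.AUTOTUNE)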
Example #4
  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A string specifying the policy for how data should be
        processed by tf.data workers. Can be either "parallel_epochs" to have
        each tf.data worker process a copy of the dataset, or
        "distributed_epoch" to split a single iteration of the dataset across
        all the workers.
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it possible
        for multiple datasets to share the same job. The default behavior is
        that the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` (inclusive) to `num_consumers` (exclusive). Must be specified
        alongside `num_consumers`. When specified, consumers will read from the
        job in a strict round-robin order, instead of the default
        first-come-first-served order.
      num_consumers: (Optional.) The number of consumers which will consume from
        the job. Must be specified alongside `consumer_index`. When specified,
        consumers will read from the job in a strict round-robin order, instead
        of the default first-come-first-served order. When `num_consumers` is
        specified, the dataset must have infinite cardinality to prevent a
        producer from running out of data early and causing consumers to go out
        of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
    """
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both consumer_index and num_consumers, or neither. "
          "consumer_index: {}, num_consumers: {}".format(
              consumer_index, num_consumers))
    if num_consumers is not None and job_name is None:
      raise ValueError("job_name must be set when setting num_consumers")

    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
    if consumer_index is None:
      consumer_index = -1
    if num_consumers is None:
      num_consumers = -1

    self._dataset_id = ops.convert_to_tensor(
        dataset_id, dtype=dtypes.int64, name="dataset_id")
    self._processing_mode = ops.convert_to_tensor(
        processing_mode, dtype=dtypes.string, name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._consumer_index = ops.convert_to_tensor(
        consumer_index, dtype=dtypes.int64, name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        num_consumers, dtype=dtypes.int64, name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    # Datasets executed by the tf.data service produce compressed elements
    # represented by scalar DT_VARIANTs.
    self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)

    if num_consumers >= 0:
      variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
          dataset_id=self._dataset_id,
          processing_mode=self._processing_mode,
          address=self._address,
          protocol=self._protocol,
          job_name=self._job_name,
          consumer_index=self._consumer_index,
          num_consumers=self._num_consumers,
          max_outstanding_requests=self._max_outstanding_requests,
          task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
          iteration_counter=(
              gen_experimental_dataset_ops.dummy_iteration_counter()),
          **self._flat_structure)
    else:
      variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
          dataset_id=self._dataset_id,
          processing_mode=self._processing_mode,
          address=self._address,
          protocol=self._protocol,
          job_name=self._job_name,
          max_outstanding_requests=self._max_outstanding_requests,
          task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
          iteration_counter=(
              gen_experimental_dataset_ops.dummy_iteration_counter()),
          **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)
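
A sketch of how the round-robin arguments might be exercised through the public API, assuming a TensorFlow version where `tf.data.experimental.service.distribute` exposes `consumer_index` and `num_consumers`; the address and job name are illustrative:

import tensorflow as tf

service = "grpc://localhost:5000"  # hypothetical dispatcher address
NUM_CONSUMERS = 2

def make_consumer_dataset(consumer_index):
  # Repeat so the dataset has infinite cardinality, as the docstring
  # requires when `num_consumers` is set.
  dataset = tf.data.Dataset.range(100).repeat()
  # All consumers share one job_name and supply distinct indices, so the
  # job hands out elements in strict round-robin order.
  return dataset.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs",
          service=service,
          job_name="shared_job",
          consumer_index=consumer_index,
          num_consumers=NUM_CONSUMERS))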