Example #1
    def _experimental_distribute_datasets_from_function(self, dataset_fn):
        input_contexts = []
        num_workers = self._input_workers.num_workers
        for i in range(num_workers):
            input_contexts.append(
                distribute_lib.InputContext(
                    num_input_pipelines=num_workers,
                    input_pipeline_id=i,
                    num_replicas_in_sync=self._num_replicas_in_sync))

        return input_lib.DistributedDatasetsFromFunction(
            dataset_fn, self._input_workers, input_contexts,
            self._container_strategy())
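
In application code these strategy internals are reached through the public `tf.distribute.Strategy.distribute_datasets_from_function` API, which invokes `dataset_fn` once per input pipeline with one of the `InputContext` objects built above. A minimal usage sketch, assuming TF 2.x; the strategy choice and batch size are illustrative assumptions, not part of the example itself:

import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # illustrative assumption

def dataset_fn(input_context):
    # Each input pipeline receives its own InputContext, mirroring the
    # per-worker contexts constructed in the example above.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    dataset = tf.data.Dataset.range(1024)
    # Shard by pipeline id so that pipelines read disjoint slices of the data.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size)

strategy = tf.distribute.MirroredStrategy()  # assumed strategy for the sketch
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
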

Example #2
    def _experimental_distribute_datasets_from_function(self, dataset_fn):
        if self._cluster_spec:
            input_pipeline_id = multi_worker_util.id_in_cluster(
                self._cluster_spec, self._task_type, self._task_id)
            num_input_pipelines = multi_worker_util.worker_count(
                self._cluster_spec, self._task_type)
        else:
            input_pipeline_id = 0
            num_input_pipelines = 1

        input_context = distribute_lib.InputContext(
            num_input_pipelines=num_input_pipelines,
            input_pipeline_id=input_pipeline_id,
            num_replicas_in_sync=self._num_replicas_in_sync)

        return input_lib.DistributedDatasetsFromFunction(
            dataset_fn, self._input_workers, [input_context],
            self._container_strategy())
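
Here the multi-worker branch derives a single `InputContext` per worker process from the cluster spec. The sketch below constructs the equivalent public `tf.distribute.InputContext` for one worker of a hypothetical three-worker cluster; the cluster size and batch size are illustrative assumptions:

import tensorflow as tf

# Hypothetical cluster: three input pipelines, this worker is pipeline 1.
ctx = tf.distribute.InputContext(
    num_input_pipelines=3,
    input_pipeline_id=1,
    num_replicas_in_sync=3)

# Values a dataset_fn would observe on this worker:
print(ctx.num_input_pipelines)             # 3
print(ctx.input_pipeline_id)               # 1
print(ctx.get_per_replica_batch_size(96))  # 96 // 3 = 32
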

Example #3
    def _experimental_distribute_datasets_from_function(self, dataset_fn):
        return input_lib.DistributedDatasetsFromFunction(
            dataset_fn, self._input_workers, [distribute_lib.InputContext()],
            self._container_strategy())
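
This single-pipeline variant falls back to a default `tf.distribute.InputContext`, whose documented defaults are one input pipeline, pipeline id 0, and one replica in sync. A short sketch of what a `dataset_fn` observes in that case:

import tensorflow as tf

ctx = tf.distribute.InputContext()  # defaults: 1 pipeline, id 0, 1 replica
print(ctx.num_input_pipelines)  # 1
print(ctx.input_pipeline_id)    # 0
# Sharding with these defaults is a no-op: the single pipeline reads everything.
dataset = tf.data.Dataset.range(8).shard(ctx.num_input_pipelines,
                                         ctx.input_pipeline_id)
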
Example #4
def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None,
                                           build=True):
    """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match worker order in
      `worker_device_pairs`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    options: Default is None. `tf.distribute.InputOptions` used to control
      options on how this dataset is distributed.
    build: whether to build underlying datasets when a
      `DistributedDatasetFromFunction` is created. This is only useful for
      `ParameterServerStrategy` now.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if `options.experimental_replication_mode` and
    `options.experimental_place_dataset_on_device` are not consistent
  """
    if (options is not None and options.experimental_replication_mode !=
            input_lib.InputReplicationMode.PER_REPLICA
            and options.experimental_place_dataset_on_device):
        raise ValueError(
            "When `experimental_place_dataset_on_device` is set for dataset "
            "placement, you must also specify `PER_REPLICA` for the "
            "replication mode")

    if (options is not None and options.experimental_replication_mode
            == input_lib.InputReplicationMode.PER_REPLICA
            and options.experimental_fetch_to_device
            and options.experimental_place_dataset_on_device):
        raise ValueError(
            "`experimental_place_dataset_on_device` can not be set to True "
            "when experimental_fetch_to_device is True and "
            "replication mode is set to `PER_REPLICA`")

    if tf2.enabled():
        return input_lib.DistributedDatasetsFromFunction(
            input_workers,
            strategy,
            input_contexts=input_contexts,
            dataset_fn=dataset_fn,
            options=options,
            build=build,
        )
    else:
        return input_lib_v1.DistributedDatasetsFromFunctionV1(
            input_workers, strategy, input_contexts, dataset_fn, options)
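
The two `ValueError` branches above encode consistency rules for `tf.distribute.InputOptions`. The sketch below shows one combination that satisfies the checks and one that the first check rejects; it only constructs the options objects and does not assume any particular strategy or hardware:

import tensorflow as tf

# Accepted: PER_REPLICA replication with fetch_to_device disabled while the
# dataset is placed on device.
ok_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA,
    experimental_place_dataset_on_device=True)

# Rejected by the first check: placing the dataset on device while keeping the
# default PER_WORKER replication mode.
bad_options = tf.distribute.InputOptions(
    experimental_place_dataset_on_device=True)
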