def _experimental_distribute_datasets_from_function(self, dataset_fn):
  """Distributes the datasets produced by `dataset_fn` over all input workers.

  One `distribute_lib.InputContext` is created per input worker. Each context
  uses the worker's index as its `input_pipeline_id` and reports the total
  worker count as `num_input_pipelines`, together with this strategy's
  `num_replicas_in_sync`.

  Args:
    dataset_fn: a function that returns a `tf.data.Dataset`; the contexts built
      here are passed to its call(s) by `DistributedDatasetsFromFunction`.

  Returns:
    An `input_lib.DistributedDatasetsFromFunction` wrapping `dataset_fn`.
  """
  worker_count = self._input_workers.num_workers
  # Contexts are ordered by worker index so they line up with input_workers.
  per_worker_contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=worker_count,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=self._num_replicas_in_sync)
      for pipeline_id in range(worker_count)
  ]
  return input_lib.DistributedDatasetsFromFunction(
      dataset_fn, self._input_workers, per_worker_contexts,
      self._container_strategy())
def _experimental_distribute_datasets_from_function(self, dataset_fn):
  """Distributes the datasets produced by `dataset_fn` for this worker.

  A single `distribute_lib.InputContext` is built for this worker's input
  pipeline. Its id and the pipeline count depend on whether a cluster spec is set:

  - With a cluster spec, the id comes from `multi_worker_util.id_in_cluster`
    and the count from `multi_worker_util.worker_count`.
  - Without one, the worker is treated as pipeline 0 of 1.

  Args:
    dataset_fn: a function that returns a `tf.data.Dataset`; the context built
      here is passed to its call by `DistributedDatasetsFromFunction`.

  Returns:
    An `input_lib.DistributedDatasetsFromFunction` wrapping `dataset_fn`.
  """
  if not self._cluster_spec:
    # No cluster configured: this worker is the sole input pipeline.
    pipeline_id, pipeline_count = 0, 1
  else:
    pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    pipeline_count = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  context = distribute_lib.InputContext(
      num_input_pipelines=pipeline_count,
      input_pipeline_id=pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return input_lib.DistributedDatasetsFromFunction(
      dataset_fn, self._input_workers, [context],
      self._container_strategy())
def _experimental_distribute_datasets_from_function(self, dataset_fn):
  """Wraps `dataset_fn` in an `input_lib.DistributedDatasetsFromFunction`.

  Exactly one `distribute_lib.InputContext` is supplied. It is
  default-constructed, so it carries whatever defaults `InputContext` defines.

  Args:
    dataset_fn: a function that returns a `tf.data.Dataset`; the context built
      here is passed to its call by `DistributedDatasetsFromFunction`.

  Returns:
    An `input_lib.DistributedDatasetsFromFunction` wrapping `dataset_fn`.
  """
  default_context = distribute_lib.InputContext()
  strategy = self._container_strategy()
  return input_lib.DistributedDatasetsFromFunction(
      dataset_fn, self._input_workers, [default_context], strategy)
def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None,
                                           build=True):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on whether we are in a TF 1 or TF 2 context. The returned
  instances differ from each other in the APIs each of them supports.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match the worker order in
      `input_workers`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    options: Default is None. `tf.distribute.InputOptions` used to control
      options on how this dataset is distributed.
    build: whether to build underlying datasets when a
      `DistributedDatasetsFromFunction` is created. This is only useful for
      `ParameterServerStrategy` now. It is only forwarded on the TF 2 path;
      the TF 1 (`DistributedDatasetsFromFunctionV1`) path does not receive it.

  Returns:
    A distributed dataset instance:
    - `input_lib.DistributedDatasetsFromFunction` when `tf2.enabled()`.
    - `input_lib_v1.DistributedDatasetsFromFunctionV1` otherwise.

  Raises:
    ValueError: if `options.experimental_place_dataset_on_device` is set while
      `options.experimental_replication_mode` is not `PER_REPLICA`, or if it is
      set together with `options.experimental_fetch_to_device` while the
      replication mode is `PER_REPLICA`.
  """
  if options is not None:
    is_per_replica = (
        options.experimental_replication_mode ==
        input_lib.InputReplicationMode.PER_REPLICA)
    place_on_device = options.experimental_place_dataset_on_device
    # Placing the dataset on device requires PER_REPLICA replication.
    if not is_per_replica and place_on_device:
      raise ValueError(
          "When `experimental_place_dataset_on_device` is set for dataset "
          "placement, you must also specify `PER_REPLICA` for the "
          "replication mode")
    # Under PER_REPLICA, fetching to device and placing the dataset on device
    # may not both be enabled.
    if (is_per_replica and options.experimental_fetch_to_device and
        place_on_device):
      raise ValueError(
          "`experimental_place_dataset_on_device` can not be set to True "
          "when experimental_fetch_to_device is True and "
          "replication mode is set to `PER_REPLICA`")

  if tf2.enabled():
    return input_lib.DistributedDatasetsFromFunction(
        input_workers,
        strategy,
        input_contexts=input_contexts,
        dataset_fn=dataset_fn,
        options=options,
        build=build,
    )
  # TF 1 path: `build` is not part of the V1 constructor call.
  return input_lib_v1.DistributedDatasetsFromFunctionV1(
      input_workers, strategy, input_contexts, dataset_fn, options)