Example #1
    def _experimental_distribute_datasets_from_function(self, dataset_fn):
        input_contexts = []
        num_workers = self._input_workers.num_workers
        for i in range(num_workers):
            input_contexts.append(
                distribute_lib.InputContext(
                    num_input_pipelines=num_workers,
                    input_pipeline_id=i,
                    num_replicas_in_sync=self._num_replicas_in_sync))

        return input_lib.get_distributed_datasets_from_function(
            dataset_fn, self._input_workers, input_contexts,
            self._container_strategy())
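A minimal caller-side sketch of how the per-pipeline InputContext objects built above reach user code; the strategy choice, global batch size, and dataset contents are illustrative assumptions, not taken from the snippet:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
    # input_context is one of the InputContext objects created per input
    # pipeline above; use it for per-replica batching and input sharding.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    dataset = tf.data.Dataset.range(1024).batch(batch_size)
    return dataset.shard(input_context.num_input_pipelines,
                         input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)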
Example #2
  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy())
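For context, a sketch of the call this guard rejects, assuming the method above belongs to a strategy that only supports per-worker input replication (tf.distribute.OneDeviceStrategy is used here purely as an illustration):

import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)

# Requesting PER_REPLICA replication here is expected to raise the
# NotImplementedError shown in the snippet above.
dist_dataset = strategy.distribute_datasets_from_function(
    lambda ctx: tf.data.Dataset.range(8).batch(2), options)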
Example #3
    def _experimental_distribute_datasets_from_function(self, dataset_fn):
        if self._cluster_spec:
            input_pipeline_id = multi_worker_util.id_in_cluster(
                self._cluster_spec, self._task_type, self._task_id)
            num_input_pipelines = multi_worker_util.worker_count(
                self._cluster_spec, self._task_type)
        else:
            input_pipeline_id = 0
            num_input_pipelines = 1

        input_context = distribute_lib.InputContext(
            num_input_pipelines=num_input_pipelines,
            input_pipeline_id=input_pipeline_id,
            num_replicas_in_sync=self._num_replicas_in_sync)

        return input_lib.get_distributed_datasets_from_function(
            dataset_fn, self._input_workers, [input_context],
            self._container_strategy())
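A sketch of a dataset_fn that consumes the single per-worker InputContext built above, sharding input files by input_pipeline_id so each worker reads a distinct subset; the file pattern, record format, and batch size are hypothetical:

import tensorflow as tf

def dataset_fn(input_context):
    # Keep only this worker's shard of the file list, using the pipeline id
    # and pipeline count supplied by the InputContext above.
    files = tf.data.Dataset.list_files("/tmp/train-*.tfrecord", shuffle=False)
    files = files.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)
    records = files.interleave(tf.data.TFRecordDataset,
                               num_parallel_calls=tf.data.AUTOTUNE)
    return records.batch(64)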
Example #4
    def _experimental_distribute_datasets_from_function(
            self, dataset_fn, options):
        input_workers = self._get_input_workers(options)
        input_contexts = []
        num_workers = input_workers.num_workers
        for i in range(num_workers):
            input_contexts.append(
                distribute_lib.InputContext(
                    num_input_pipelines=num_workers,
                    input_pipeline_id=i,
                    num_replicas_in_sync=self._num_replicas_in_sync))

        distributed_dataset = input_lib.get_distributed_datasets_from_function(
            dataset_fn, input_workers, input_contexts,
            self._container_strategy())

        # We can only check after the dataset_fn is called.
        if options is None or options.experimental_prefetch_to_device:
            self._check_spec(distributed_dataset.element_spec)
        return distributed_dataset
Example #5
    def _distribute_datasets_from_function(self, dataset_fn, options):
        # There is no synchronization beyond a worker and thus, the number of
        # input pipelines in sync is only 1 per worker.
        input_pipeline_id_in_sync = 0
        num_input_pipelines_in_sync = 1

        input_context = distribute_lib.InputContext(
            num_input_pipelines=num_input_pipelines_in_sync,
            input_pipeline_id=input_pipeline_id_in_sync,
            num_replicas_in_sync=self._num_replicas_in_sync)

        # If this DistributedDatasetFromFunction is created outside
        # ClusterCoordinator, i.e. outside a tf.function, we don't build its
        # underlying datasets until it is passed to
        # ClusterCoordinator.create_per_worker_dataset.
        return input_lib.get_distributed_datasets_from_function(
            dataset_fn,
            self._input_workers_with_options(options), [input_context],
            self._container_strategy(),
            options=options,
            build=ops.inside_function())  # will be built by ClusterCoordinator
Example #6
    def _distribute_datasets_from_function(self, dataset_fn, options):
        self._assert_used_with_cluster_coordinator()
        if not ops.get_default_graph().building_function:
            raise ValueError(
                "The `distribute_datasets_from_function` method must be called "
                "inside a `tf.function` passed to `create_per_worker_dataset` of "
                "`tf.distribute.experimental.coordinator.ClusterCoordinator`")

        # There is no synchronization beyond a worker and thus, the number of
        # input pipelines in sync is only 1 per worker.
        input_pipeline_id_in_sync = 0
        num_input_pipelines_in_sync = 1

        input_context = distribute_lib.InputContext(
            num_input_pipelines=num_input_pipelines_in_sync,
            input_pipeline_id=input_pipeline_id_in_sync,
            num_replicas_in_sync=self._num_replicas_in_sync)

        return input_lib.get_distributed_datasets_from_function(
            dataset_fn,
            self._input_workers_with_options(options), [input_context],
            self._container_strategy(),
            options=options)
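A minimal end-to-end sketch of the calling pattern this check enforces, assuming the snippet above comes from tf.distribute.experimental.ParameterServerStrategy; the cluster setup via TF_CONFIG is abbreviated:

import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

@tf.function
def per_worker_dataset_fn():
    # Called inside a tf.function handed to create_per_worker_dataset, so the
    # graph-building check above passes and the dataset is built per worker.
    return strategy.distribute_datasets_from_function(
        lambda ctx: tf.data.Dataset.range(100).batch(8))

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)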
Example #7
  def _distribute_datasets_from_function(self, dataset_fn, options):
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn, self._input_workers_with_options(options),
        [distribute_lib.InputContext()], self._container_strategy())
Example #8
  def _experimental_distribute_datasets_from_function(self, dataset_fn):
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn, self._input_workers, [distribute_lib.InputContext()],
        self._container_strategy())