def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): return input_lib_v1.InputFunctionIterator( input_fn, self._input_workers, [distribute_lib.InputContext()], self._container_strategy())
def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): """Distributes the input function to each local GPU.""" input_context = self._make_input_context() return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, [input_context], self._container_strategy())
def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): input_contexts = [] num_workers = self._input_workers.num_workers for i in range(num_workers): input_contexts.append( distribute_lib.InputContext( num_input_pipelines=num_workers, input_pipeline_id=i, num_replicas_in_sync=self._num_replicas_in_sync)) return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, input_contexts, self._container_strategy())
def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): """Distributes the dataset to each local GPU.""" if self._cluster_spec: input_pipeline_id = multi_worker_util.id_in_cluster( self._cluster_spec, self._task_type, self._task_id) num_input_pipelines = multi_worker_util.worker_count( self._cluster_spec, self._task_type) else: input_pipeline_id = 0 num_input_pipelines = 1 input_context = distribute_lib.InputContext( num_input_pipelines=num_input_pipelines, input_pipeline_id=input_pipeline_id, num_replicas_in_sync=self._num_replicas_in_sync) return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, [input_context], self._container_strategy())