示例#1
0
 def _make_input_fn_iterator(
         self,
         input_fn,
         replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
     """Wraps `input_fn` in an iterator fed by one default `InputContext`.

     Args:
         input_fn: Input function forwarded to `InputFunctionIterator`.
         replication_mode: Not read by this implementation.

     Returns:
         An `input_lib_v1.InputFunctionIterator` built over
         `self._input_workers` for the container strategy.
     """
     workers = self._input_workers
     # Exactly one default-constructed context is supplied to the iterator.
     contexts = [distribute_lib.InputContext()]
     strategy = self._container_strategy()
     return input_lib_v1.InputFunctionIterator(
         input_fn, workers, contexts, strategy)
示例#2
0
 def _make_input_fn_iterator(
     self,
     input_fn,
     replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
   """Distributes the input function to each local GPU.

   Args:
     input_fn: Input function forwarded to `InputFunctionIterator`.
     replication_mode: Not read by this implementation.

   Returns:
     An `input_lib_v1.InputFunctionIterator` over `self._input_workers`,
     driven by the single context from `self._make_input_context()`.
   """
   # `_make_input_context` is defined elsewhere on this class.
   context = self._make_input_context()
   return input_lib_v1.InputFunctionIterator(
       input_fn,
       self._input_workers,
       [context],
       self._container_strategy(),
   )
示例#3
0
 def _make_input_fn_iterator(
         self,
         input_fn,
         replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
     """Builds one `InputContext` per input worker and wraps `input_fn`.

     Context `k` (for `k` in `range(num_workers)`) gets
     `input_pipeline_id=k`, `num_input_pipelines=num_workers` and
     `num_replicas_in_sync=self._num_replicas_in_sync`.

     Args:
         input_fn: Input function forwarded to `InputFunctionIterator`.
         replication_mode: Not read by this implementation.

     Returns:
         An `input_lib_v1.InputFunctionIterator` over `self._input_workers`.
     """
     worker_count = self._input_workers.num_workers
     contexts = [
         distribute_lib.InputContext(
             num_input_pipelines=worker_count,
             input_pipeline_id=worker_id,
             num_replicas_in_sync=self._num_replicas_in_sync)
         for worker_id in range(worker_count)
     ]
     return input_lib_v1.InputFunctionIterator(
         input_fn, self._input_workers, contexts, self._container_strategy())
示例#4
0
 def _make_input_fn_iterator(
     self,
     input_fn,
     replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
   """Distributes the dataset to each local GPU.

   The single `InputContext` takes its pipeline id and pipeline count from
   `multi_worker_util` when `self._cluster_spec` is truthy. Otherwise it uses
   id 0 and count 1.

   Args:
     input_fn: Input function forwarded to `InputFunctionIterator`.
     replication_mode: Not read by this implementation.

   Returns:
     An `input_lib_v1.InputFunctionIterator` over `self._input_workers`.
   """
   if not self._cluster_spec:
     # No cluster spec: behave as the only input pipeline.
     pipeline_id, pipeline_count = 0, 1
   else:
     pipeline_id = multi_worker_util.id_in_cluster(
         self._cluster_spec, self._task_type, self._task_id)
     pipeline_count = multi_worker_util.worker_count(
         self._cluster_spec, self._task_type)
   context = distribute_lib.InputContext(
       num_input_pipelines=pipeline_count,
       input_pipeline_id=pipeline_id,
       num_replicas_in_sync=self._num_replicas_in_sync)
   return input_lib_v1.InputFunctionIterator(
       input_fn, self._input_workers, [context], self._container_strategy())