Example #1
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  # Single-device case: wrap the dataset with a default (single-pipeline)
  # InputContext and place it on this strategy's one device.
  return values.PerReplicaDataset(
      self._call_dataset_fn(input_fn, distribute_lib.InputContext()),
      [self._device])
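
A side note, not from the snippet's codebase: the internal distribute_lib.InputContext is exported publicly as tf.distribute.InputContext, and its constructor defaults describe exactly this single-pipeline case. A minimal standalone sketch:

import tensorflow as tf

# Default context: one input pipeline, pipeline id 0, one replica in sync.
ctx = tf.distribute.InputContext()
assert ctx.num_input_pipelines == 1
assert ctx.input_pipeline_id == 0
assert ctx.num_replicas_in_sync == 1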
Example #2
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  # One local worker: a single worker/device pair and a single default
  # InputContext feed the iterator.
  return values.InputFunctionIterator(
      input_fn, [("/job:localhost", [self._device])],
      [distribute_lib.InputContext()])
Example #3
def testPerReplicaBatchSize(self):
  input_context = distribute_lib.InputContext(num_input_pipelines=2,
                                              input_pipeline_id=1,
                                              num_replicas_in_sync=6)
  # 12 splits evenly across 6 replicas; 13 does not, so it must raise.
  self.assertEqual(2, input_context.get_per_replica_batch_size(12))
  with self.assertRaises(ValueError):
    input_context.get_per_replica_batch_size(13)
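
The ValueError this test expects comes from the divisibility rule: the global batch size must split evenly across num_replicas_in_sync. A standalone sketch of the same check through the public tf.distribute.InputContext (assumed equivalent to the internal class used here):

import tensorflow as tf

ctx = tf.distribute.InputContext(num_input_pipelines=2,
                                 input_pipeline_id=1,
                                 num_replicas_in_sync=6)
print(ctx.get_per_replica_batch_size(12))  # 12 / 6 replicas -> 2
try:
  ctx.get_per_replica_batch_size(13)  # 13 is not divisible by 6
except ValueError as err:
  print("rejected:", err)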
Example #4
def testProperties(self):
  input_context = distribute_lib.InputContext(num_input_pipelines=2,
                                              input_pipeline_id=1,
                                              num_replicas_in_sync=6)
  # Constructor arguments are exposed unchanged as read-only properties.
  self.assertEqual(6, input_context.num_replicas_in_sync)
  self.assertEqual(1, input_context.input_pipeline_id)
  self.assertEqual(2, input_context.num_input_pipelines)
Example #5
def _test_iterator(self,
                   input_fn,
                   worker_device_pairs,
                   expected_values,
                   sess=None):
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  input_contexts = [
      distribute_lib.InputContext() for _ in worker_device_pairs
  ]
  iterator = values.InputFunctionIterator(input_fn, worker_device_pairs,
                                          input_contexts)

  evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

  evaluate(iterator.initialize())

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_device(d, next_element) for d in devices])
    self.assertEqual(expected_value, computed_value)

  with self.assertRaises(errors.OutOfRangeError):
    next_element = iterator.get_next()
    evaluate([values.select_device(d, next_element) for d in devices])

  # After re-initializing the iterator, we should be able to iterate again.
  evaluate(iterator.initialize())

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_device(d, next_element) for d in devices])
    self.assertEqual(expected_value, computed_value)
Example #6
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  if self._cluster_spec:
    # Multi-worker: build one input_fn per worker, each bound to an
    # InputContext with its own input_pipeline_id.
    input_fns = []
    for i in range(len(self._worker_devices)):
      input_context = distribute_lib.InputContext(
          num_input_pipelines=len(self._worker_devices),
          input_pipeline_id=i,
          num_replicas_in_sync=self.num_replicas_in_sync)
      input_fns.append(
          partial(self._call_dataset_fn, input_fn, input_context))

    return values.MultiWorkerDataset(input_fns, self._worker_devices,
                                     self._auto_shard_dataset)
  else:
    # Single worker: one pipeline feeding all local devices.
    input_context = distribute_lib.InputContext(
        num_input_pipelines=1,
        input_pipeline_id=0,
        num_replicas_in_sync=self.num_replicas_in_sync)
    return values.PerReplicaDataset(
        self._call_dataset_fn(input_fn, input_context), self._devices)
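
The multi-worker branch above binds a distinct InputContext to each worker's input_fn via functools.partial. A standalone sketch of that binding pattern, using the public API (make_worker_fns and its arguments are illustrative, not from the source):

from functools import partial

import tensorflow as tf

def make_worker_fns(input_fn, num_workers, num_replicas):
  # One callable per worker; the contexts differ only in input_pipeline_id,
  # as in the loop of Example #6.
  return [
      partial(input_fn,
              tf.distribute.InputContext(
                  num_input_pipelines=num_workers,
                  input_pipeline_id=i,
                  num_replicas_in_sync=num_replicas))
      for i in range(num_workers)
  ]

worker_fns = make_worker_fns(lambda ctx: print(ctx.input_pipeline_id),
                             num_workers=2, num_replicas=2)
for fn in worker_fns:
  fn()  # prints 0, then 1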
Example #7
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec is None:
    input_pipeline_id = 0
  else:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
  input_context = distribute_lib.InputContext(
      num_input_pipelines=self._num_workers,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return values.PerReplicaDataset(
      self._call_dataset_fn(input_fn, input_context), self._devices, True)
Example #8
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  input_contexts = []
  if self._cluster_spec:
    num_workers = len(self._worker_devices)
    worker_device_pairs = self._worker_devices
  else:
    num_workers = 1
    worker_device_pairs = [("/job:localhost", self._devices)]
  # One InputContext per worker; only input_pipeline_id differs.
  for i in range(num_workers):
    input_contexts.append(
        distribute_lib.InputContext(
            num_input_pipelines=num_workers,
            input_pipeline_id=i,
            num_replicas_in_sync=self._num_replicas_in_sync))
  return values.InputFunctionIterator(input_fn, worker_device_pairs,
                                      input_contexts)
Example #9
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    num_input_pipelines = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  else:
    input_pipeline_id = 0
    num_input_pipelines = 1
  input_context = distribute_lib.InputContext(
      num_input_pipelines=num_input_pipelines,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  worker_device_pairs = [(self._worker_device, self._compute_devices)]
  return values.InputFunctionIterator(input_fn, worker_device_pairs,
                                      [input_context])
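
All of these _make_input_fn_iterator variants ultimately hand an InputContext to a user-supplied input_fn, which typically uses it to shard and batch the dataset. A minimal sketch of such an input_fn using the public API (the dataset contents and GLOBAL_BATCH_SIZE are placeholders):

import tensorflow as tf

GLOBAL_BATCH_SIZE = 8

def input_fn(ctx):
  # Derive the per-replica batch size from the global batch size, then
  # shard so each input pipeline reads a disjoint slice of the data.
  batch_size = ctx.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  ds = tf.data.Dataset.range(64)
  ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return ds.batch(batch_size)

ds = input_fn(tf.distribute.InputContext())  # single-pipeline default
for batch in ds.take(2):
  print(batch.numpy())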