Example no. 1
 def _make_dataset_iterator(self, dataset):
     """Make iterators for each of the TPU hosts."""
     return input_lib.DatasetIterator(
         dataset,
         self._input_workers,
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
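For context, these `_make_dataset_iterator` hooks are not called directly. A minimal sketch of how they are typically reached, assuming a TF 1.x release (e.g. 1.15) in which `tf.distribute.Strategy` still exposes `make_dataset_iterator`; the strategy and dataset below are illustrative only:

import tensorflow as tf

# Illustrative strategy and dataset; any distribution strategy with the
# 1.x-era API would do.
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(8).batch(4)

# Strategy.make_dataset_iterator delegates to the extended strategy's
# _make_dataset_iterator hook, which builds the input_lib.DatasetIterator
# seen in these examples.
iterator = strategy.make_dataset_iterator(dataset)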
Example no. 2
    def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
                         devices, split_batch_by, enable_get_next_as_optional):
        device_map = values.ReplicaDeviceMap(devices)
        input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

        if input_type == "input_fn":
            input_contexts = []
            for i in range(input_workers.num_workers):
                input_contexts.append(
                    distribute_lib.InputContext(
                        num_input_pipelines=input_workers.num_workers,
                        input_pipeline_id=i,
                        num_replicas_in_sync=len(devices)))

            iterator = input_lib.InputFunctionIterator(
                dataset_fn,
                input_workers,
                input_contexts,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        else:
            iterator = input_lib.DatasetIterator(
                dataset_fn(distribute_lib.InputContext()),
                input_workers,
                split_batch_by,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        return iterator
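A sketch of how this helper might be invoked for the "input_fn" path, written as a test method in the same class; the device names, dataset, and flag values are assumed for illustration, and `dataset_ops` is the tf.data ops module these tests import:

    def testInputFnIteratorSketch(self):
        # Hypothetical test; one worker feeding two GPU replicas.
        worker_device_pairs = [("/device:CPU:0",
                                ["/device:GPU:0", "/device:GPU:1"])]
        devices = ["/device:GPU:0", "/device:GPU:1"]
        # For the "input_fn" path, dataset_fn receives one InputContext per
        # input pipeline.
        dataset_fn = lambda ctx: dataset_ops.Dataset.range(8).batch(4)
        iterator = self._create_iterator(
            "input_fn", dataset_fn, worker_device_pairs, devices,
            split_batch_by=None, enable_get_next_as_optional=False)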
Example no. 3
 def _make_dataset_iterator(self, dataset):
     """Distributes the dataset to each local GPU."""
     input_context = self._make_input_context()
     return input_lib.DatasetIterator(dataset,
                                      self._input_workers,
                                      self._num_replicas_in_sync,
                                      input_context=input_context)
Example no. 4
 def _make_dataset_iterator(self, dataset):
     """Make iterators for each of the TPU hosts."""
     input_workers = input_lib.InputWorkers(
         tuple(self._device_input_worker_devices.items()))
     return input_lib.DatasetIterator(
         dataset,
         input_workers,
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
Example no. 5
    def _make_dataset_iterator(self, dataset):
        """Make iterator from dataset without splitting the batch.

        This implementation is different from the one in
        `tf.distribute.MirroredStrategy` for purposes of backward compatibility.
        We treat the incoming dataset's batch size as the per-replica batch
        size.

        Args:
          dataset: `tf.data.Dataset` for input.

        Returns:
          An `InputIterator` which returns inputs for each step of the
          computation.
        """
        return input_lib.DatasetIterator(dataset, self._input_workers)
Example no. 6
    def _test_iterator(self,
                       input_type,
                       dataset_fn,
                       worker_device_pairs,
                       expected_values,
                       sess=None,
                       split_batch_by=None):
        devices = nest.flatten([ds for _, ds in worker_device_pairs])
        device_map = values.ReplicaDeviceMap(devices)
        input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

        if input_type == "input_fn":
            input_contexts = [
                distribute_lib.InputContext() for _ in worker_device_pairs
            ]
            input_fn = lambda _: dataset_fn()
            iterator = input_lib.InputFunctionIterator(input_fn, input_workers,
                                                       input_contexts)
        else:
            iterator = input_lib.DatasetIterator(dataset_fn(), input_workers,
                                                 split_batch_by)

        evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

        evaluate(control_flow_ops.group(iterator.initialize()))

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate([
                values.select_replica(r, next_element)
                for r in range(len(devices))
            ])
            self.assertAllEqual(expected_value, computed_value)

        with self.assertRaises(errors.OutOfRangeError):
            next_element = iterator.get_next()
            evaluate([
                values.select_replica(r, next_element)
                for r in range(len(devices))
            ])

        # After re-initializing the iterator, should be able to iterate again.
        evaluate(control_flow_ops.group(iterator.initialize()))

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate([
                values.select_replica(r, next_element)
                for r in range(len(devices))
            ])
            self.assertAllEqual(expected_value, computed_value)
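A sketch of a test case that drives the helper above; the device names, dataset, and expected per-replica batches are assumed for illustration:

    def testTwoDevicesWithBatchSplittingSketch(self):
        # Hypothetical test: two replica devices on one worker. With global
        # batches of 4 and split_batch_by=2, each replica sees batches of 2.
        worker_device_pairs = [("/device:CPU:0",
                                ["/device:GPU:0", "/device:GPU:1"])]
        dataset_fn = lambda: dataset_ops.Dataset.range(8).batch(4)
        expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
        self._test_iterator(
            "dataset",
            dataset_fn,
            worker_device_pairs,
            expected_values,
            split_batch_by=2)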
Example no. 7
    def _wrap_iterator(self,
                       input_type,
                       dataset_fn,
                       input_workers,
                       devices,
                       split_batch_by,
                       enable_get_next_as_optional,
                       strategy,
                       input_context=None):
        # The `input_context` passed in is used to shard the dataset for
        # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case,
        # where multiple InputContexts are needed.
        if input_type == "input_fn":
            self.assertIsNone(
                input_context,
                msg=("The `input_context` arg is only used to shard the "
                     "dataset in `MultiWorkerMirroredStrategy` when the input "
                     "type is dataset."))

            input_contexts = []
            for i in range(input_workers.num_workers):
                input_contexts.append(
                    distribute_lib.InputContext(
                        # Note: `input_workers.num_workers` is always 1 in the
                        # between-graph case.
                        num_input_pipelines=input_workers.num_workers,
                        input_pipeline_id=i,
                        num_replicas_in_sync=len(devices)))

            iterator = input_lib.InputFunctionIterator(
                dataset_fn,
                input_workers,
                input_contexts,
                strategy,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        else:
            iterator = input_lib.DatasetIterator(
                dataset_fn(input_context),
                input_workers,
                strategy,
                split_batch_by=split_batch_by,
                input_context=input_context,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        return iterator
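To make the comment about `input_context` concrete, a sketch of how the "dataset" branch might be exercised with an explicit sharding context; the pipeline counts, `input_workers`, `devices`, and `distribution` strategy below are assumed, not taken from the original test:

        # Hypothetical continuation of a test body; `input_workers`, `devices`,
        # and `distribution` are assumed to have been built earlier in the test.
        input_context = distribute_lib.InputContext(
            num_input_pipelines=2,   # worker 0 of an assumed two-worker cluster
            input_pipeline_id=0,
            num_replicas_in_sync=2)
        iterator = self._wrap_iterator(
            input_type="dataset",
            dataset_fn=lambda ctx: dataset_ops.Dataset.range(8).batch(4),
            input_workers=input_workers,
            devices=devices,
            split_batch_by=None,
            enable_get_next_as_optional=False,
            strategy=distribution,
            input_context=input_context)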
Example no. 8
 def _make_dataset_iterator(self, dataset):
   return input_lib.DatasetIterator(
       dataset,
       self._input_workers,
       self._container_strategy(),
       num_replicas_in_sync=self._num_replicas_in_sync)
Example no. 9
 def _make_dataset_iterator(self, dataset):
     return input_lib.DatasetIterator(dataset, self._input_workers,
                                      self._num_replicas_in_sync)
Example no. 10
 def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch."""
     # Note that the split_batch_by argument is not passed because it is always
     # 1 in this strategy, and passing it adds unnecessary overhead to the
     # dataset.
     return input_lib.DatasetIterator(dataset, self._input_workers,
                                      self._container_strategy())
Example no. 11
 def _make_dataset_iterator(self, dataset):
     return input_lib.DatasetIterator(dataset, self._input_workers)
Example no. 12
 def _make_dataset_iterator(self, dataset):
   """Make iterators for each of the TPU hosts."""
   return input_lib.DatasetIterator(dataset, self._input_workers,
                                    self._num_replicas_in_sync)
Example no. 13
 def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch."""
     return input_lib.DatasetIterator(dataset, self._input_workers)
Example no. 14
 def _make_dataset_iterator(self, dataset):
     return input_lib.DatasetIterator(dataset, self._input_workers,
                                      self._container_strategy())
Example no. 15
 def _make_dataset_iterator(self, dataset):
     """Make iterators for each of the TPU hosts."""
     return input_lib.DatasetIterator(dataset,
                                      self._input_workers,
                                      self._num_replicas_in_sync,
                                      _enable_get_next_as_optional=True)