Example #1
  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not context.executing_eagerly():
      raise ValueError("Eager reset is only supported in eager mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)
Example #2
  def _eager_reset(self):
   """Resets the MultiDeviceIterator in eager mode."""
   if not context.executing_eagerly():
     raise ValueError("Eager reset is only supported in eager mode.")
   # pylint: disable=protected-access
   self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
       self._dataset._variant_tensor,
       self._multi_device_iterator_resource,
       max_buffer_size=self._max_buffer_size)
   for i, device in enumerate(self._devices):
     with ops.device(device):
       ds = self._create_device_dataset(i)
       # Reset the device iterator resources with the new dataset.
       ds_variant = ds._variant_tensor
       gen_dataset_ops.make_iterator(
           ds_variant, self._device_iterators[i]._iterator_resource)
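
Both Example #1 and Example #2 guard the reset so it only runs under eager execution. A minimal usage sketch for the surrounding class (assuming TensorFlow 2.x and its public tf.data.experimental.MultiDeviceIterator wrapper; the two logical CPU devices are configured only so the snippet runs on any machine):

import tensorflow as tf

# Assumption for illustration: split the physical CPU into two logical
# devices so no multi-GPU machine is needed.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(),
          tf.config.LogicalDeviceConfiguration()])

dataset = tf.data.Dataset.range(8)
mdi = tf.data.experimental.MultiDeviceIterator(
    dataset, devices=["/cpu:0", "/cpu:1"])
for _ in range(2):
  first, second = mdi.get_next()  # one element per device
  print(first.numpy(), second.numpy())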
Example #3
 def _eager_reset(self):
   """Resets the MultiDeviceIterator in eager mode."""
   if not ops.executing_eagerly_outside_functions():
     raise ValueError(
         "Resetting a multi-device iterator is only supported in the eager "
         "mode.")
   # pylint: disable=protected-access
   self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
       self._dataset._variant_tensor,
       self._multi_device_iterator_resource,
       max_buffer_size=self._max_buffer_size)
   for i, device in enumerate(self._devices):
     with ops.device(device):
       ds = _create_device_dataset(self._prototype_device_datasets[i],
                                   self._incarnation_id,
                                   self._prefetch_buffer_size,
                                   self._experimental_slack)
       # Reset the device iterator resources with the new dataset.
       ds_variant = ds._variant_tensor
       gen_dataset_ops.make_iterator(
           ds_variant, self._device_iterators[i]._iterator_resource)
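
Example #3 delegates the rebuild to a module-level _create_device_dataset helper instead of a method. A hedged reconstruction of what that helper does, inferred from the call site above (_ReincarnatedPerDeviceGenerator and the exact signature are assumptions, not confirmed by this page):

def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
                           experimental_slack):
  """Rebuilds a per-device dataset from its prototype (sketch)."""
  # Rebind the prototype generator to the fresh incarnation id (assumed
  # internal dataset class).
  ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
  if prefetch_buffer_size > 0:
    if experimental_slack:
      ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size,
                                       slack_period=1)
    else:
      ds = ds.prefetch(prefetch_buffer_size)
  return ds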
Example #4
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

    Raises:
      RuntimeError: If run in Eager mode.
    """
    if context.executing_eagerly():
      # TODO(rohanj): Fix this. Tracking bug: b/116467184
      raise RuntimeError("MultiDeviceIterator is not currently supported in "
                         "Eager mode.")
    self._dataset = dataset
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._dataset.output_shapes,
                               self._dataset.output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._dataset.output_types,
                              self._dataset.output_classes))

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name="",
              container="",
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._as_variant_tensor(),  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=max_buffer_size)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    i = 0
    for device in self._devices:
      ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                               self._incarnation_id,
                               self._source_device_tensor, device,
                               self._dataset.output_shapes,
                               self._dataset.output_types,
                               self._dataset.output_classes)
      if prefetch_buffer_size > 0:
        ds = ds.prefetch(prefetch_buffer_size)
      with ops.device(device):
        self._device_iterators.append(ds.make_initializable_iterator())
      i += 1

    device_iterator_initializers = [
        iterator.initializer for iterator in self._device_iterators
    ]
    self._initializer = control_flow_ops.group(*device_iterator_initializers)
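
This revision predates eager support, so the iterator is consumed through its initializer inside a session. A minimal graph-mode sketch (TF 1.x style via tf.compat.v1; the single CPU device is chosen only for portability):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

dataset = tf.data.Dataset.range(4)
mdi = tf.data.experimental.MultiDeviceIterator(dataset, devices=["/cpu:0"])
elements = mdi.get_next()  # one element per device

with tf.Session() as sess:
  sess.run(mdi.initializer)
  print(sess.run(elements))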
Example #5
  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs a MultiDeviceIteratorV2 object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A nested structure of `TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode or outside of function building
        mode.
    """
    if (not context.executing_eagerly() and
        not ops.get_default_graph()._building_function):  # pylint: disable=protected-access
      raise RuntimeError("MultiDeviceIteratorV2 is only supported inside of "
                         "tf.function or when eager execution is enabled.")
    if devices is None:
      raise ValueError("`devices` must be provided")
    error_message = ("Either `dataset` or both `components` and "
                     "`element_spec` need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._devices = devices
      self._source_device = source_device
      self._multi_device_iterator_resource = components[0]
      self._deleter = components[1]
      self._device_iterators = components[2:]
      iterator_handles = []
      for it in self._device_iterators:
        iterator_handles.append(it._iterator_resource)  # pylint: disable=protected-access
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      options = dataset_ops.Options()
      options.experimental_distribute.num_devices = len(devices)
      dataset = dataset.with_options(options)
      dataset = dataset._apply_options()  # pylint: disable=protected-access
      self._element_spec = dataset.element_spec
      experimental_slack = dataset.options().experimental_slack
      self._devices = devices
      self._source_device = source_device
      source_device_tensor = ops.convert_to_tensor(self._source_device)

      if prefetch_buffer_size > max_buffer_size:
        max_buffer_size = prefetch_buffer_size

      # Create the MultiDeviceIterator.
      with ops.device(self._source_device):
        self._multi_device_iterator_resource, self._deleter = (
            gen_dataset_ops.anonymous_multi_device_iterator(
                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

        # The incarnation ID is used to ensure consistency between the
        # per-device iterators and the multi-device iterator.
        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._multi_device_iterator_resource,
            max_buffer_size=max_buffer_size)

      prototype_device_datasets = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                   incarnation_id, source_device_tensor,
                                   dataset.element_spec)
          prototype_device_datasets.append(ds)

      # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
      # initialize the device side of the pipeline. This would allow the
      # MultiDeviceIterator to choose, for example, to move some
      # transformations into the device side from its input. It might be
      # useful in rewriting.
      # Create the per device iterators.
      self._device_iterators = []
      iterator_handles = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)
          iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

    self._resource_deleter = MultiDeviceIteratorResourceDeleter(
        multi_device_iterator=self._multi_device_iterator_resource,
        iterators=iterator_handles,
        device=self._source_device,
        deleter=self._deleter)
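
In this V2 path the resource is anonymous and its lifetime is tied to the Python object through MultiDeviceIteratorResourceDeleter, so no explicit initializer is involved. Consumption is plain eager iteration; get_next also accepts a specific device when only that device's element is needed (a sketch, assuming the TF 2.x public wrapper):

import tensorflow as tf

dataset = tf.data.Dataset.range(4)
mdi = tf.data.experimental.MultiDeviceIterator(dataset, devices=["/cpu:0"])
element = mdi.get_next("/cpu:0")  # fetch only this device's element
print(element.numpy())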
Example #6
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
    """
    options = dataset_ops.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **self._dataset._flat_structure))  # pylint: disable=protected-access
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                 self._incarnation_id,
                                 self._source_device_tensor,
                                 self._dataset.element_spec)
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)
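
The clamp at the top of this constructor means the per-device prefetch buffer can never exceed the host-side buffer. The two constructions below therefore end up with the same effective buffer sizes (sketch; the single CPU device is an assumption for portability):

import tensorflow as tf

dataset = tf.data.Dataset.range(8)
# max_buffer_size=1 is silently raised to 4 to match prefetch_buffer_size.
it_a = tf.data.experimental.MultiDeviceIterator(
    dataset, ["/cpu:0"], max_buffer_size=1, prefetch_buffer_size=4)
# Passing 4 explicitly yields the same configuration.
it_b = tf.data.experimental.MultiDeviceIterator(
    dataset, ["/cpu:0"], max_buffer_size=4, prefetch_buffer_size=4)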
Example #7
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.
    """
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **dataset_ops.flat_structure(dataset)))
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=max_buffer_size)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(i,
                                 self._multi_device_iterator_resource,
                                 self._incarnation_id,
                                 self._source_device_tensor,
                                 dataset._element_structure)  # pylint: disable=protected-access
        if prefetch_buffer_size > 0:
          ds = ds.prefetch(prefetch_buffer_size)
        # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
        # non-CPU devices.
        options = dataset_ops.Options()
        options.experimental_autotune = False
        options.experimental_optimization.apply_default_optimizations = False
        ds = ds.with_options(options)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)
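
This revision disables autotuning and the default static optimizations on the device-side pipeline. The same options object can be built in user code, roughly as follows (assuming a TF 1.13-era tf.data.Options, where experimental_autotune was still a top-level field):

import tensorflow as tf

options = tf.data.Options()
options.experimental_autotune = False
options.experimental_optimization.apply_default_optimizations = False
device_ds = tf.data.Dataset.range(4).with_options(options)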
Example #8
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

    Raises:
      RuntimeError: If run in Eager mode.
    """
    if context.executing_eagerly():
      # TODO(rohanj): Fix this. Tracking bug: b/116467184
      raise RuntimeError("MultiDeviceIterator is not currently supported in "
                         "Eager mode.")
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._dataset.output_shapes,
                               self._dataset.output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._dataset.output_types,
                              self._dataset.output_classes))

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name="",
              container="",
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._as_variant_tensor(),  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=max_buffer_size)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    i = 0
    for device in self._devices:
      ds = _PerDeviceGenerator(
          i, self._multi_device_iterator_resource, self._incarnation_id,
          self._source_device_tensor, device, self._dataset.output_shapes,
          self._dataset.output_types, self._dataset.output_classes)
      if prefetch_buffer_size > 0:
        ds = ds.prefetch(prefetch_buffer_size)
      with ops.device(device):
        self._device_iterators.append(ds.make_initializable_iterator())
      i += 1

    device_iterator_initializers = [
        iterator.initializer for iterator in self._device_iterators
    ]
    self._initializer = control_flow_ops.group(*device_iterator_initializers)
Example #9
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.
    """
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        # Ensure a unique name when eager execution is enabled to avoid spurious
        # sharing issues.
        shared_name += str(ops.uid())
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **dataset_ops.flat_structure(dataset)))
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=max_buffer_size)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i, self._multi_device_iterator_resource, self._incarnation_id,
            self._source_device_tensor, dataset._element_structure)  # pylint: disable=protected-access
        if prefetch_buffer_size > 0:
          ds = ds.prefetch(prefetch_buffer_size)
        # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
        # non-CPU devices.
        options = dataset_ops.Options()
        options.experimental_autotune = False
        options.experimental_optimization.apply_default_optimizations = False
        ds = ds.with_options(options)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)
Example #10
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
    """
    options = dataset_ops.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **dataset_ops.flat_structure(self._dataset)))
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i, self._multi_device_iterator_resource, self._incarnation_id,
            self._source_device_tensor, self._dataset._element_structure)  # pylint: disable=protected-access
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)
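
Across these revisions the iterator flavor is chosen by execution mode: one-shot in eager, initializable plus a grouped initializer in graph mode. The same split in user code, sketched with tf.compat.v1 helpers:

import tensorflow as tf

dataset = tf.data.Dataset.range(2)
if tf.executing_eagerly():
  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
else:
  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)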