def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.contrib.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Args:
      dataset: A `tf.contrib.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

        # This iterator drives the dataset imperatively, so it is only
        # usable when eager execution is on.
        if not context.in_eager_mode():
            raise RuntimeError(
                "{} objects only make sense when eager execution is enabled".
                format(type(self)))
        # Materialize the dataset as a resource the iterator ops can consume.
        ds_variant = dataset.make_dataset_resource()
        self._output_types = dataset.output_types
        # The iterator ops take flattened (per-component) types and shapes.
        self._flat_output_types = nest.flatten(dataset.output_types)
        self._flat_output_shapes = nest.flatten(dataset.output_shapes)
        self._resource = gen_dataset_ops.iterator(
            container="",
            shared_name=_iterator_shared_name(),
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
        # Bind the iterator resource to the dataset so iteration starts
        # from the first element.
        gen_dataset_ops.make_iterator(ds_variant, self._resource)
Example no. 2
0
    def _create_iterator(self, dataset):
        """Creates the underlying iterator resource over `dataset`."""
        # pylint: disable=protected-access
        dataset = dataset._apply_options()

        # Store dataset reference to ensure that dataset is alive when this iterator
        # is being used. For example, `tf.data.Dataset.from_generator` registers
        # a few py_funcs that are needed in `self._next_internal`.  If the dataset
        # is deleted, this iterator crashes on `self.__next__(...)` call.
        self._dataset = dataset

        ds_variant = dataset._variant_tensor
        self._element_spec = dataset.element_spec
        # The iterator ops take flattened (per-component) types and shapes.
        self._flat_output_types = structure.get_flat_tensor_types(
            self._element_spec)
        self._flat_output_shapes = structure.get_flat_tensor_shapes(
            self._element_spec)
        # Keep the iterator resource on the same device as the dataset variant.
        with ops.colocate_with(ds_variant):
            self._iterator_resource, self._deleter = (
                gen_dataset_ops.anonymous_iterator_v2(
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes))
            gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = IteratorResourceDeleter(
                handle=self._iterator_resource,
                device=self._device,
                deleter=self._deleter)
Example no. 3
0
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.contrib.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Args:
      dataset: A `tf.contrib.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    # This iterator drives the dataset imperatively, so it is only usable
    # when eager execution is on.
    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects only make sense when eager execution is enabled".format(
              type(self)))
    ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
    self._output_types = dataset.output_types
    # The iterator ops take flattened (per-component) types and shapes.
    self._flat_output_types = nest.flatten(dataset.output_types)
    self._flat_output_shapes = nest.flatten(dataset.output_shapes)
    self._resource = gen_dataset_ops.iterator(
        container="",
        shared_name=_iterator_shared_name(),
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
    # Bind the iterator resource to the dataset.
    gen_dataset_ops.make_iterator(ds_variant, self._resource)
Example no. 4
0
    def _create_iterator(self, dataset):
        """Creates the underlying iterator resource over `dataset`."""
        # pylint: disable=protected-access
        dataset = dataset._apply_debug_options()

        # Store dataset reference to ensure that dataset is alive when this iterator
        # is being used. For example, `tf.data.Dataset.from_generator` registers
        # a few py_funcs that are needed in `self._next_internal`.  If the dataset
        # is deleted, this iterator crashes on `self.__next__(...)` call.
        self._dataset = dataset

        ds_variant = dataset._variant_tensor
        self._element_spec = dataset.element_spec
        # The iterator ops take flattened (per-component) types and shapes.
        self._flat_output_types = structure.get_flat_tensor_types(
            self._element_spec)
        self._flat_output_shapes = structure.get_flat_tensor_shapes(
            self._element_spec)
        # Keep the iterator resource on the same device as the dataset variant.
        with ops.colocate_with(ds_variant):
            self._iterator_resource = (gen_dataset_ops.anonymous_iterator_v3(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes))
            if not context.executing_eagerly():
                # Add full type information to the graph so host memory types inside
                # variants stay on CPU, e.g, ragged string tensors.
                # TODO(b/224776031) Remove this when AnonymousIterateV3 can use
                # (reverse) type inference and all other ops that are needed to
                # provide type information to the AnonymousIterateV3 also support
                # type inference (esp. cross-function type inference) instead of
                # setting the full type information manually.
                fulltype = type_utils.iterator_full_type_from_spec(
                    self._element_spec)
                # fulltype is PRODUCT[ITERATOR[PRODUCT[...]]]
                assert len(fulltype.args[0].args[0].args) == len(
                    self._flat_output_types)
                self._iterator_resource.op.experimental_set_type(fulltype)
            gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
Example no. 5
0
 def _eager_reset(self):
     """Re-initializes this MultiDeviceIterator while executing eagerly.

     Raises:
       ValueError: If called outside of eager mode.
     """
     if not context.executing_eagerly():
         raise ValueError("Eager reset is only supported in eager mode.")
     # pylint: disable=protected-access
     # Re-running the init op yields a fresh incarnation id for the resource.
     self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
         self._dataset._variant_tensor,
         self._multi_device_iterator_resource,
         max_buffer_size=self._max_buffer_size)
     for idx, dev in enumerate(self._devices):
         with ops.device(dev):
             per_device_ds = self._create_device_dataset(idx)
             # Point the existing per-device iterator resource at the
             # freshly created device dataset.
             gen_dataset_ops.make_iterator(
                 per_device_ds._variant_tensor,
                 self._device_iterators[idx]._iterator_resource)
Example no. 6
0
  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` with compatible structure to this iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure.
    """
    with ops.name_scope(name, "make_initializer") as name:
      nest.assert_same_structure(self._output_types, dataset.output_types)
      nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
      # Dtypes must match exactly, component by component.
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              "Expected output types %r but got dataset with output types %r." %
              (self._output_types, dataset.output_types))
      # Shapes need only be compatible (a partially-defined iterator shape
      # accepts a more specific dataset shape), hence is_compatible_with
      # rather than equality.
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self._output_shapes),
          nest.flatten(dataset.output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError("Expected output shapes compatible with %r but got "
                          "dataset with output shapes %r." %
                          (self._output_shapes, dataset.output_shapes))
    with ops.colocate_with(self._iterator_resource):
      return gen_dataset_ops.make_iterator(
          dataset._as_variant_tensor(), self._iterator_resource, name=name)  # pylint: disable=protected-access
Example no. 7
0
    def make_initializer(self, dataset, name=None):
        """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` with compatible structure to this iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure.
    """
        with ops.name_scope(name, "make_initializer") as name:
            # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
            # to that creating a circular dependency.
            # pylint: disable=protected-access
            # Recover legacy (types/shapes/classes) views of the dataset's
            # element spec for structure comparison below.
            dataset_output_types = nest.map_structure(
                lambda component_spec: component_spec._to_legacy_output_types(
                ), dataset.element_spec)
            dataset_output_shapes = nest.map_structure(
                lambda component_spec: component_spec._to_legacy_output_shapes(
                ), dataset.element_spec)
            dataset_output_classes = nest.map_structure(
                lambda component_spec: component_spec.
                _to_legacy_output_classes(), dataset.element_spec)
            # pylint: enable=protected-access

            nest.assert_same_structure(self.output_types, dataset_output_types)
            nest.assert_same_structure(self.output_shapes,
                                       dataset_output_shapes)
            # Component classes (e.g. Tensor vs. SparseTensor) must match
            # exactly (identity comparison).
            for iterator_class, dataset_class in zip(
                    nest.flatten(self.output_classes),
                    nest.flatten(dataset_output_classes)):
                if iterator_class is not dataset_class:
                    raise TypeError(
                        "Expected output classes %r but got dataset with output class %r."
                        % (self.output_classes, dataset_output_classes))
            # Dtypes must match exactly, component by component.
            for iterator_dtype, dataset_dtype in zip(
                    nest.flatten(self.output_types),
                    nest.flatten(dataset_output_types)):
                if iterator_dtype != dataset_dtype:
                    raise TypeError(
                        "Expected output types %r but got dataset with output types %r."
                        % (self.output_types, dataset_output_types))
            # Shapes need only be compatible, not identical.
            for iterator_shape, dataset_shape in zip(
                    nest.flatten(self.output_shapes),
                    nest.flatten(dataset_output_shapes)):
                if not iterator_shape.is_compatible_with(dataset_shape):
                    raise TypeError(
                        "Expected output shapes compatible with %r but got "
                        "dataset with output shapes %r." %
                        (self.output_shapes, dataset_output_shapes))

        with ops.colocate_with(dataset._variant_tensor):
            return gen_dataset_ops.make_iterator(dataset._variant_tensor,
                                                 self._iterator_resource,
                                                 name=name)  # pylint: disable=protected-access
 def _eager_reset(self):
   """Resets the MultiDeviceIterator in eager mode."""
   if not context.executing_eagerly():
     raise ValueError("Eager reset is only supported in eager mode.")
   # pylint: disable=protected-access
   # Re-running the init op produces a fresh incarnation id for the resource.
   self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
       self._dataset._variant_tensor,
       self._multi_device_iterator_resource,
       max_buffer_size=self._max_buffer_size)
   for i, device in enumerate(self._devices):
     with ops.device(device):
       ds = self._create_device_dataset(i)
       # Reset the device iterator resources with the new dataset.
       ds_variant = ds._variant_tensor
       gen_dataset_ops.make_iterator(
           ds_variant, self._device_iterators[i]._iterator_resource)
Example no. 9
0
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

        if not context.executing_eagerly():
            raise RuntimeError(
                "{} objects can only be used when eager execution is enabled, use "
                "tf.data.Dataset.make_initializable_iterator or "
                "tf.data.Dataset.make_one_shot_iterator for graph construction"
                .format(type(self)))
        # Iterator state is kept on the host CPU.
        with ops.device("/device:CPU:0"):
            ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
            self._output_classes = dataset.output_classes
            self._output_types = dataset.output_types
            self._output_shapes = dataset.output_shapes
            # Sparse components are described by their dense underlying
            # types/shapes when flattened for the iterator ops.
            self._flat_output_types = nest.flatten(
                sparse.as_dense_types(self._output_types,
                                      self._output_classes))
            self._flat_output_shapes = nest.flatten(
                sparse.as_dense_shapes(self._output_shapes,
                                       self._output_classes))
            # NOTE(review): the generated name is passed as `container` while
            # `shared_name` is empty — the opposite of sibling constructors
            # elsewhere in this file; confirm this is intentional.
            self._resource = gen_dataset_ops.iterator(
                shared_name="",
                container=_generate_shared_name("eageriterator"),
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            gen_dataset_ops.make_iterator(ds_variant, self._resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._resource, handle_device="/device:CPU:0")
        self._device = context.context().device_name
Example no. 10
0
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    # Iterator state is kept on the host CPU.
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      # Sparse components are described by their dense underlying types/shapes
      # when flattened for the iterator ops.
      self._flat_output_types = nest.flatten(
          sparse.as_dense_types(self._output_types, self._output_classes))
      self._flat_output_shapes = nest.flatten(
          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
      # NOTE(review): the generated name is passed as `container` while
      # `shared_name` is empty — the opposite of sibling constructors
      # elsewhere in this file; confirm this is intentional.
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=_generate_shared_name("eageriterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
Example no. 11
0
 def _create_iterator(self, dataset):
   """Creates the underlying iterator resource over `dataset`."""
   # pylint: disable=protected-access
   dataset = dataset._apply_options()
   ds_variant = dataset._variant_tensor
   self._structure = dataset._element_structure
   # The iterator ops take flattened (per-component) types and shapes.
   self._flat_output_types = self._structure._flat_types
   self._flat_output_shapes = self._structure._flat_shapes
   # Keep the iterator resource on the same device as the dataset variant.
   with ops.colocate_with(ds_variant):
     self._iterator_resource, self._deleter = (
         gen_dataset_ops.anonymous_iterator_v2(
             output_types=self._flat_output_types,
             output_shapes=self._flat_output_shapes))
     gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
     # Delete the resource when this object is deleted
     self._resource_deleter = IteratorResourceDeleter(
         handle=self._iterator_resource,
         device=self._device,
         deleter=self._deleter)
Example no. 12
0
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

        if not context.executing_eagerly():
            raise RuntimeError(
                "{} objects can only be used when eager execution is enabled, use "
                "tf.data.Dataset.make_initializable_iterator or "
                "tf.data.Dataset.make_one_shot_iterator for graph construction"
                .format(type(self)))
        # Remember the creating device; the iterator state itself lives on
        # the host CPU.
        self._device = context.context().device_name
        with ops.device("/cpu:0"):
            # pylint: disable=protected-access
            dataset = dataset._apply_options()
            ds_variant = dataset._variant_tensor
            # Build a structure object from the legacy (types/shapes/classes)
            # description of the dataset's elements.
            self._structure = structure_lib.convert_legacy_structure(
                dataset.output_types, dataset.output_shapes,
                dataset.output_classes)
            self._flat_output_types = self._structure._flat_types
            self._flat_output_shapes = self._structure._flat_shapes
            with ops.colocate_with(ds_variant):
                self._resource = gen_dataset_ops.anonymous_iterator(
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)
                gen_dataset_ops.make_iterator(ds_variant, self._resource)
                # Delete the resource when this object is deleted
                self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                    handle=self._resource, handle_device=self._device)
Example no. 13
0
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    # Remember the creating device; the iterator state itself lives on the
    # host CPU.
    self._device = context.context().device_name
    with ops.device("/cpu:0"):
      # pylint: disable=protected-access
      dataset = dataset._apply_options()
      ds_variant = dataset._variant_tensor
      # Build a structure object from the legacy (types/shapes/classes)
      # description of the dataset's elements.
      self._structure = structure_lib.convert_legacy_structure(
          dataset.output_types, dataset.output_shapes, dataset.output_classes)
      self._flat_output_types = self._structure._flat_types
      self._flat_output_shapes = self._structure._flat_shapes
      with ops.colocate_with(ds_variant):
        self._resource = gen_dataset_ops.anonymous_iterator(
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
        gen_dataset_ops.make_iterator(ds_variant, self._resource)
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._resource, handle_device=self._device)
Example no. 14
0
    def _init_func():
      """Builds an iterator over the input dataset.

      Returns:
        A scalar `string` tensor encoding a handle to the new iterator.
      """
      dataset_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      iterator_resource = gen_dataset_ops.anonymous_iterator(
          **dataset_ops.flat_structure(self._input_dataset))
      # The handle is only meaningful once the iterator has been bound to
      # the dataset, hence the control dependency.
      init_op = gen_dataset_ops.make_iterator(dataset_variant,
                                              iterator_resource)
      with ops.control_dependencies([init_op]):
        return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
Example no. 15
0
 def _eager_reset(self):
   """Resets the MultiDeviceIterator in eager mode."""
   if not ops.executing_eagerly_outside_functions():
     raise ValueError(
         "Resetting a multi-device iterator is only supported in the eager "
         "mode.")
   # pylint: disable=protected-access
   # Re-running the init op produces a fresh incarnation id for the resource.
   self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
       self._dataset._variant_tensor,
       self._multi_device_iterator_resource,
       max_buffer_size=self._max_buffer_size)
   for i, device in enumerate(self._devices):
     with ops.device(device):
       # Rebuild the per-device dataset from its prototype with the new
       # incarnation id.
       ds = _create_device_dataset(self._prototype_device_datasets[i],
                                   self._incarnation_id,
                                   self._prefetch_buffer_size,
                                   self._experimental_slack)
       # Reset the device iterator resources with the new dataset.
       ds_variant = ds._variant_tensor
       gen_dataset_ops.make_iterator(
           ds_variant, self._device_iterators[i]._iterator_resource)
Example no. 16
0
    def _init_func():
      """Builds an iterator over the input dataset.

      Returns:
        A scalar `string` tensor encoding a handle to the new iterator.
      """
      dataset_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      iterator_resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The handle is only meaningful once the iterator has been bound to
      # the dataset, hence the control dependency.
      init_op = gen_dataset_ops.make_iterator(dataset_variant,
                                              iterator_resource)
      with ops.control_dependencies([init_op]):
        return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
Example no. 17
0
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

        # NOTE(review): the docstring promises a RuntimeError outside eager
        # mode, but this variant performs no such check — confirm whether the
        # caller guarantees eager execution.
        self._device = context.context().device_name
        with ops.device("/cpu:0"):
            # pylint: disable=protected-access
            dataset = dataset._apply_options()
            ds_variant = dataset._variant_tensor
            self._structure = dataset._element_structure
            self._flat_output_types = self._structure._flat_types
            self._flat_output_shapes = self._structure._flat_shapes
            with ops.colocate_with(ds_variant):
                self._iterator_resource, self._deleter = (
                    gen_dataset_ops.anonymous_iterator_v2(
                        output_types=self._flat_output_types,
                        output_shapes=self._flat_output_shapes))
                gen_dataset_ops.make_iterator(ds_variant,
                                              self._iterator_resource)
                # Delete the resource when this object is deleted
                self._resource_deleter = IteratorResourceDeleter(
                    handle=self._iterator_resource,
                    device=self._device,
                    deleter=self._deleter)
Example no. 18
0
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    # NOTE(review): the docstring promises a RuntimeError outside eager mode,
    # but this variant performs no such check — confirm whether the caller
    # guarantees eager execution.
    self._device = context.context().device_name
    with ops.device("/cpu:0"):
      # pylint: disable=protected-access
      dataset = dataset._apply_options()
      ds_variant = dataset._variant_tensor
      self._structure = dataset._element_structure
      self._flat_output_types = self._structure._flat_types
      self._flat_output_shapes = self._structure._flat_shapes
      with ops.colocate_with(ds_variant):
        self._iterator_resource, self._deleter = (
            gen_dataset_ops.anonymous_iterator_v2(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes))
        gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
        # Delete the resource when this object is deleted
        self._resource_deleter = IteratorResourceDeleter(
            handle=self._iterator_resource,
            device=self._device,
            deleter=self._deleter)
Example no. 19
0
        def _init_func():
            """Builds an iterator over the input dataset.

            Returns:
              A scalar `string` tensor encoding a handle to the new iterator.
            """
            dataset_variant = gen_dataset_ops.unwrap_dataset_variant(
                wrap_ds_variant)
            iterator_resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            # The handle is only meaningful once the iterator has been bound
            # to the dataset, hence the control dependency.
            init_op = gen_dataset_ops.make_iterator(dataset_variant,
                                                    iterator_resource)
            with ops.control_dependencies([init_op]):
                return gen_dataset_ops.iterator_to_string_handle(
                    iterator_resource)
Example no. 20
0
        def _init_func():
            """Builds an iterator over the input dataset.

            Returns:
              A scalar `string` tensor encoding a handle to the new iterator.
            """
            # pylint: disable=protected-access
            dataset_variant = self._input_dataset._as_variant_tensor()
            iterator_resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            # The handle is only meaningful once the iterator has been bound
            # to the dataset, hence the control dependency.
            init_op = gen_dataset_ops.make_iterator(dataset_variant,
                                                    iterator_resource)
            with ops.control_dependencies([init_op]):
                return gen_dataset_ops.iterator_to_string_handle(
                    iterator_resource)
Example no. 21
0
    def _init_func():
      """Builds an iterator over the input dataset.

      Returns:
        A scalar `string` tensor encoding a handle to the new iterator.
      """
      # pylint: disable=protected-access
      dataset_variant = self._input_dataset._as_variant_tensor()
      iterator_resource = gen_dataset_ops.anonymous_iterator(
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      # The handle is only meaningful once the iterator has been bound to
      # the dataset, hence the control dependency.
      init_op = gen_dataset_ops.make_iterator(dataset_variant,
                                              iterator_resource)
      with ops.control_dependencies([init_op]):
        return gen_dataset_ops.iterator_to_string_handle(iterator_resource)
Example no. 22
0
    def make_initializer(self, dataset, name=None):
        """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` with compatible structure to this iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure.
    """
        with ops.name_scope(name, "make_initializer") as name:
            nest.assert_same_structure(self._output_types,
                                       dataset.output_types)
            nest.assert_same_structure(self._output_shapes,
                                       dataset.output_shapes)
            # Component classes (e.g. Tensor vs. SparseTensor) must match
            # exactly (identity comparison).
            for iterator_class, dataset_class in zip(
                    nest.flatten(self._output_classes),
                    nest.flatten(dataset.output_classes)):
                if iterator_class is not dataset_class:
                    raise TypeError(
                        "Expected output classes %r but got dataset with output class %r."
                        % (self._output_classes, dataset.output_classes))
            # Dtypes must match exactly, component by component.
            for iterator_dtype, dataset_dtype in zip(
                    nest.flatten(self._output_types),
                    nest.flatten(dataset.output_types)):
                if iterator_dtype != dataset_dtype:
                    raise TypeError(
                        "Expected output types %r but got dataset with output types %r."
                        % (self._output_types, dataset.output_types))
            # Shapes need only be compatible, not identical.
            for iterator_shape, dataset_shape in zip(
                    nest.flatten(self._output_shapes),
                    nest.flatten(dataset.output_shapes)):
                if not iterator_shape.is_compatible_with(dataset_shape):
                    raise TypeError(
                        "Expected output shapes compatible with %r but got "
                        "dataset with output shapes %r." %
                        (self._output_shapes, dataset.output_shapes))
        with ops.colocate_with(self._iterator_resource):
            return gen_dataset_ops.make_iterator(dataset._as_variant_tensor(),
                                                 self._iterator_resource,
                                                 name=name)  # pylint: disable=protected-access
Example no. 23
0
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

        if not context.in_eager_mode():
            raise RuntimeError(
                "{} objects can only be used when eager execution is enabled, use "
                "tf.data.Dataset.make_iterator or "
                "tf.data.Dataset.make_one_shot_iterator for graph construction"
                .format(type(self)))
        # The iterator resource itself always lives on the host CPU.
        with ops.device("/device:CPU:0"):
            ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
            self._output_types = dataset.output_types
            self._output_shapes = dataset.output_shapes
            self._flat_output_types = nest.flatten(dataset.output_types)
            self._flat_output_shapes = nest.flatten(dataset.output_shapes)
            self._resource = gen_dataset_ops.iterator(
                container="",
                shared_name=_generate_shared_name("eager_iterator"),
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            gen_dataset_ops.make_iterator(ds_variant, self._resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._resource, handle_device="/device:CPU:0")
        self._device = context.context().device_name
        self._buffer_resource_handle = None
        # An empty device_type means no explicit device was specified, which
        # is treated as the (local CPU) default.
        if not context.context().device_spec.device_type:
            is_remote_device = False
        else:
            is_remote_device = context.context(
            ).device_spec.device_type != "CPU"
        # When the iterator was created on a non-CPU device, stage elements
        # through a function-buffering resource that pulls from the CPU-side
        # iterator and prefetches onto the target device.
        if is_remote_device:
            with ops.device("/device:CPU:0"):
                iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
                    self._resource)

                # This function is serialized into the graph and executed on
                # the CPU to fetch the next element for the buffer.
                @function.Defun(dtypes.string)
                def remote_fn(h):
                    remote_iterator = iterator_ops.Iterator.from_string_handle(
                        h, self._output_types, self._output_shapes)
                    return remote_iterator.get_next()

                remote_fn.add_to_graph(None)
                target = constant_op.constant("/device:CPU:0")
            with ops.device(self._device):
                self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    string_arg=iter_string_handle,
                    f=remote_fn,
                    target_device=target,
                    buffer_size=10,
                    thread_pool_size=1,
                    container="",
                    shared_name=_generate_shared_name(
                        "function_buffer_resource"))
                self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
                    handle=self._buffer_resource_handle,
                    handle_device=self._device)
Esempio n. 24
0
  def __init__(self, dataset):
    """Builds an eager-mode iterator over `dataset`.

    Example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors yielded by this iterator are placed on the device that was
    active when the iterator object was constructed.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: If eager execution is not enabled.
    """

    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, "
          "use tf.data.Dataset.make_iterator or tf.data.Dataset."
          "make_one_shot_iterator for graph construction".format(type(self)))
    # The iterator resource itself always lives on the CPU, regardless of
    # which device the consumed tensors end up on.
    with ops.device("/device:CPU:0"):
      variant_tensor = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      self._flat_output_types = nest.flatten(dataset.output_types)
      self._flat_output_shapes = nest.flatten(dataset.output_shapes)
      self._resource = gen_dataset_ops.iterator(
          container="",
          shared_name=_generate_shared_name("eager_iterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(variant_tensor, self._resource)
      # Tie the lifetime of the underlying resource to this Python object.
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
    self._buffer_resource_handle = None
    device_type = context.context().device_spec.device_type
    # A non-CPU active device means elements must be prefetched from the
    # CPU-hosted iterator through a function-buffering resource.
    if device_type and device_type != "CPU":
      with ops.device("/device:CPU:0"):
        string_handle = gen_dataset_ops.iterator_to_string_handle(
            self._resource)

        @function.Defun(dtypes.string)
        def remote_fn(h):
          # Reconstitute the iterator from its string handle and pull the
          # next element; executed by the buffering resource.
          remote_iterator = iterator_ops.Iterator.from_string_handle(
              h, self._output_types, self._output_shapes)
          return remote_iterator.get_next()

        remote_fn.add_to_graph(None)
        target = constant_op.constant("/device:CPU:0")
      with ops.device(self._device):
        self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
            string_arg=string_handle,
            f=remote_fn,
            target_device=target,
            buffer_size=10,
            thread_pool_size=1,
            container="",
            shared_name=_generate_shared_name("function_buffer_resource"))
        self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._buffer_resource_handle, handle_device=self._device)