def materialize(self, shared_name=None, container=None):
    """Materialize creates a MaterializedIndexedDataset.

    IndexedDatasets can be combined through operations such as TBD. Therefore,
    they are only materialized when absolutely required.

    Args:
      shared_name: a string for the shared name to use for the resource.
        `None` is passed to the op as the empty string.
      container: a string for the container to store the resource. `None` is
        passed to the op as the empty string.

    Returns:
      A MaterializedIndexedDataset.
    """
    if container is None:
      container = ""
    if shared_name is None:
      shared_name = ""
    materialized_resource = (
        ged_ops.experimental_materialized_index_dataset_handle(
            container=container,
            shared_name=shared_name,
            output_types=nest.flatten(
                sparse.as_dense_types(self.output_types, self.output_classes)),
            # Shapes are converted with `as_dense_shapes`; `as_dense_types` is
            # the converter applied to `output_types` and was used here by
            # mistake.
            output_shapes=nest.flatten(
                sparse.as_dense_shapes(self.output_shapes,
                                       self.output_classes))))

    # Create the materialize op colocated with the resource handle.
    with ops.colocate_with(materialized_resource):
      materializer = ged_ops.experimental_indexed_dataset_materialize(
          self._as_variant_tensor(), materialized_resource)
    return MaterializedIndexedDataset(materialized_resource, materializer,
                                      self.output_classes, self.output_types,
                                      self.output_shapes)
Esempio n. 2
0
  def _make_key_func(self, key_func, input_dataset):
    """Builds the Defun wrapper around `key_func` and stores it on `self`.

    The wrapper is saved as `self._key_func` and added to the default graph.

    Args:
      key_func: Callable invoked on the deserialized input element (unpacked
        into positional arguments when `dataset_ops._should_unpack_args`
        returns true).
      input_dataset: Dataset whose `output_types`, `output_shapes` and
        `output_classes` describe the wrapper's arguments.

    Raises:
      ValueError: If the converted result of `key_func` is not `tf.int64`.
    """
    flat_arg_types = nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes))

    @function.Defun(*flat_arg_types)
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Re-apply the static shapes known from `input_dataset` to each argument.
      flat_arg_shapes = nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))
      for arg, shape in zip(args, flat_arg_shapes):
        arg.set_shape(shape)

      element = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(input_dataset.output_types, args),
          input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      unpack = dataset_ops._should_unpack_args(element)
      # pylint: enable=protected-access
      key = key_func(*element) if unpack else key_func(element)
      key = ops.convert_to_tensor(key, dtype=dtypes.int64)
      if key.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return key

    self._key_func = tf_key_func
    default_graph = ops.get_default_graph()
    self._key_func.add_to_graph(default_graph)
Esempio n. 3
0
 def _as_variant_tensor(self):
   """Returns the result of the `ignore_errors_dataset` op for this dataset."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   return gen_dataset_ops.ignore_errors_dataset(
       input_variant, output_shapes=flat_shapes, output_types=flat_types)
Esempio n. 4
0
  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Every call increments an internal counter; once the counter exceeds
    `GET_NEXT_CALL_WARNING_THRESHOLD`, `warnings.warn` is called with
    `GET_NEXT_CALL_WARNING_MESSAGE`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_element = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)

    # Rebuild the nested structure, then restore any serialized sparse tensors.
    element = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        element, self._output_types, self._output_shapes, self._output_classes)
Esempio n. 5
0
  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`.

    Fetches one element from every buffering resource and returns the results
    packed into the structure of `self._devices`.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    per_device_elements = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for buffer_resource in self._buffering_resources:
      flat_types = data_nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes))
      flat_element = gen_dataset_ops.function_buffering_resource_get_next(
          buffer_resource, output_types=flat_types, name=name)

      packed = data_nest.pack_sequence_as(self.output_types, flat_element)
      element = sparse.deserialize_sparse_tensors(
          packed, self.output_types, self.output_shapes, self.output_classes)

      # Only plain tensors get a static shape assigned here.
      flat_shapes = data_nest.flatten(self.output_shapes)
      for tensor, shape in zip(data_nest.flatten(element), flat_shapes):
        if isinstance(tensor, ops.Tensor):
          tensor.set_shape(shape)
      per_device_elements.append(element)

    return nest.pack_sequence_as(self._devices, per_device_elements)
Esempio n. 6
0
def get_next_as_optional(iterator):
  """Returns an `Optional` that contains the next value from the iterator.

  If `iterator` has reached the end of the sequence, the returned `Optional`
  will have no value.

  Args:
    iterator: A `tf.data.Iterator` object.

  Returns:
    An `Optional` object representing the next value from the iterator (if it
    has one) or no value.
  """
  # pylint: disable=protected-access
  resource = iterator._iterator_resource
  flat_types = nest.flatten(
      sparse.as_dense_types(iterator.output_types, iterator.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(iterator.output_shapes, iterator.output_classes))
  optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
      resource, output_types=flat_types, output_shapes=flat_shapes)

  # Describe the value's structure from the iterator's legacy type/shape/class
  # triple.
  value_structure = structure.Structure._from_legacy_structure(
      iterator.output_types, iterator.output_shapes, iterator.output_classes)
  return optional_ops._OptionalImpl(optional_variant, value_structure)
Esempio n. 7
0
 def _as_variant_tensor(self):
   """Returns the result of the `set_stats_aggregator_dataset` op."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   aggregator_resource = self._stats_aggregator._resource
   # pylint: enable=protected-access
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.set_stats_aggregator_dataset(
       input_variant,
       aggregator_resource,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 8
0
 def _as_variant_tensor(self):
   """Returns the result of the `random_dataset` op built from the seeds."""
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   return gen_dataset_ops.random_dataset(
       seed=self._seed,
       seed2=self._seed2,
       output_shapes=flat_shapes,
       output_types=flat_types)
  def get(self, index):
    """Get retrieves a value (or set of values) from the IndexedDataset.

    The op's output is returned as-is; it is not repacked into the nested
    structure of the dataset (see the TODO below).

    Args:
      index: A uint64 scalar or vector tensor with the indices to retrieve.

    Returns:
      A tensor containing the values corresponding to `index`.
    """
    # TODO(saeta): nest.pack_sequence_as(...)
    return ged_ops.experimental_indexed_dataset_get(
        self._materialized_resource,
        index,
        output_types=nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes)),
        # Shapes are converted with `as_dense_shapes`; `as_dense_types` is the
        # converter applied to `_output_types` and was used here by mistake.
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self._output_shapes, self._output_classes)))
Esempio n. 10
0
 def _as_variant_tensor(self):
   """Returns the result of `self._op_function` on the input and the tag."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return self._op_function(
       input_variant,
       self._tag,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 11
0
 def _as_variant_tensor(self):
   """Returns the result of the `directed_interleave_dataset` op."""
   # pylint: disable=protected-access
   selector_variant = self._selector_input._as_variant_tensor()
   data_variants = [
       data_input._as_variant_tensor() for data_input in self._data_inputs
   ]
   # pylint: enable=protected-access
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   return gen_dataset_ops.directed_interleave_dataset(
       selector_variant,
       data_variants,
       output_shapes=flat_shapes,
       output_types=flat_types)
Esempio n. 12
0
 def _as_variant_tensor(self):
   """Returns the result of the `slide_dataset` op over the input dataset."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   return gen_dataset_ops.slide_dataset(
       input_variant,
       window_size=self._window_size,
       stride=self._stride,
       output_shapes=flat_shapes,
       output_types=flat_types)
Esempio n. 13
0
 def _as_variant_tensor(self):
   """Returns the result of the `dense_to_sparse_batch_dataset` op."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   row_shape_tensor = dataset_ops._partial_shape_to_tensor(self._row_shape)
   # pylint: enable=protected-access
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   return gen_dataset_ops.dense_to_sparse_batch_dataset(
       input_variant,
       self._batch_size,
       row_shape=row_shape_tensor,
       output_shapes=flat_shapes,
       output_types=flat_types)
Esempio n. 14
0
  def __init__(self,
               input_dataset,
               one_shot,
               devices,
               buffer_size,
               shared_name=None):
    """Sets up one function-buffering resource per device for `input_dataset`.

    Args:
      input_dataset: The dataset to iterate over.
      one_shot: If true, a one-shot iterator is created over `input_dataset`;
        otherwise an iterator is built from the dataset's structure and an
        initializer op is stored in `self._initializer`.
      devices: Structure of devices (flattened with `nest.flatten`); one
        buffering resource is created under each device.
      buffer_size: Forwarded as `buffer_size` to every buffering resource.
      shared_name: (Optional.) Shared name passed to the iterator and to the
        buffering resources. `None` is replaced by the empty string.
    """
    self._input_dataset = input_dataset
    self._get_next_call_count = 0
    self._one_shot = one_shot
    if shared_name is None:
      shared_name = ""
    self._devices = devices

    if self._one_shot:
      self._input_iterator = input_dataset.make_one_shot_iterator()
    else:
      self._input_iterator = iterator_ops.Iterator.from_structure(
          self._input_dataset.output_types, self._input_dataset.output_shapes,
          shared_name, self._input_dataset.output_classes)
    # The string handle is what the prefetch function receives as its argument.
    input_iterator_handle = self._input_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      # Rebuild an iterator from the handle, then return its next element with
      # sparse tensors serialized and the structure flattened.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, self._input_iterator.output_types,
          self._input_iterator.output_shapes,
          self._input_iterator.output_classes)
      ret = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    # The device of the input iterator's resource is the target on which each
    # buffering resource runs `_prefetch_fn`.
    target_device = ged_ops.experimental_iterator_get_device(
        self._input_iterator._iterator_resource)
    self._buffering_resources = []
    for device in nest.flatten(self._devices):
      with ops.device(device):
        buffer_resource_handle = prefetching_ops.function_buffering_resource(
            f=_prefetch_fn,
            output_types=data_nest.flatten(
                sparse.as_dense_types(self._input_dataset.output_types,
                                      self._input_dataset.output_classes)),
            target_device=target_device,
            string_arg=input_iterator_handle,
            buffer_size=buffer_size,
            shared_name=shared_name)
        self._buffering_resources.append(buffer_resource_handle)

    if not self._one_shot:
      # The initializer is created under control dependencies on a reset op
      # for every buffering resource.
      reset_ops = []
      for buffer_resource in self._buffering_resources:
        reset_ops.append(
            ged_ops.experimental_function_buffering_resource_reset(
                buffer_resource))
      with ops.control_dependencies(reset_ops):
        self._initializer = self._input_iterator.make_initializer(
            self._input_dataset)
Esempio n. 15
0
  def __init__(self,
               dataset,
               devices,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Creates the multi-device iterator resource and per-device iterators.

    Args:
      dataset: The dataset whose elements are distributed.
      devices: Iterable of devices; one initializable iterator is created
        under each.
      prefetch_buffer_size: Buffer size passed to `prefetch` on each
        per-device dataset.
      source_device: Device under which the multi-device iterator resource and
        its init op are created.
    """
    self._dataset = dataset
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    dense_shapes = sparse.as_dense_shapes(self._dataset.output_shapes,
                                          self._dataset.output_classes)
    self._flat_output_shapes = nest.flatten(dense_shapes)
    dense_types = sparse.as_dense_types(self._dataset.output_types,
                                        self._dataset.output_classes)
    self._flat_output_types = nest.flatten(dense_types)

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      resource = gen_dataset_ops.multi_device_iterator(
          devices=self._devices,
          shared_name="",
          container="",
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      self._multi_device_iterator_resource = resource

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      # pylint: disable=protected-access
      dataset_variant = self._dataset._as_variant_tensor()
      # pylint: enable=protected-access
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          dataset_variant, self._multi_device_iterator_resource)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for shard_index, device in enumerate(self._devices):
      per_device_dataset = _PerDeviceGenerator(
          shard_index, self._multi_device_iterator_resource,
          self._incarnation_id, self._source_device_tensor, device,
          self._dataset.output_shapes, self._dataset.output_types,
          self._dataset.output_classes).prefetch(prefetch_buffer_size)
      with ops.device(device):
        self._device_iterators.append(
            per_device_dataset.make_initializable_iterator())

    # A single grouped op initializes every per-device iterator.
    self._initializer = control_flow_ops.group(
        *[iterator.initializer for iterator in self._device_iterators])
Esempio n. 16
0
 def _as_variant_tensor(self):
   """Returns the result of the `scan_dataset` op for this dataset."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   flat_initial_state = nest.flatten(self._initial_state)
   captured = self._scan_func.captured_inputs
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.scan_dataset(
       input_variant,
       flat_initial_state,
       captured,
       f=self._scan_func,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 17
0
 def _as_variant_tensor(self):
   """Returns the result of the `parallel_interleave_dataset` op."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   captured = self._map_func.captured_inputs
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.parallel_interleave_dataset(
       input_variant,
       captured,
       self._cycle_length,
       self._block_length,
       self._sloppy,
       f=self._map_func,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 18
0
  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.contrib.data.parallel_interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Rebuild the nested element and restore serialized sparse tensors
      # before handing it to the user function.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Side effect: the output structure of this dataset is taken from the
      # dataset returned by `map_func`, so these attributes are only set once
      # this body has run (presumably when `add_to_graph` below traces it --
      # TODO confirm).
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
    # The defaults below are computed from the caller-supplied `block_length`
    # and `cycle_length` arguments, not from the converted tensors above.
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)
  def __init__(self,
               input_dataset,
               one_shot,
               device,
               buffer_size,
               shared_name=None):
    """Sets up a function-buffering resource on `device` for `input_dataset`.

    Args:
      input_dataset: The dataset to iterate over.
      one_shot: If true, a one-shot iterator is created over `input_dataset`;
        otherwise an iterator is built from the dataset's structure and an
        initializer op is stored in `self._initializer`.
      device: Device under which the buffering resource is created.
      buffer_size: Forwarded as `buffer_size` to the buffering resource.
      shared_name: (Optional.) Shared name passed to the iterator and to the
        buffering resource. `None` is replaced by the empty string.
    """
    self._input_dataset = input_dataset
    self._get_next_call_count = 0
    self._one_shot = one_shot
    if shared_name is None:
      shared_name = ""

    if self._one_shot:
      self._input_iterator = input_dataset.make_one_shot_iterator()
    else:
      self._input_iterator = iterator_ops.Iterator.from_structure(
          self._input_dataset.output_types, self._input_dataset.output_shapes,
          shared_name, self._input_dataset.output_classes)
    # The string handle is what the prefetch function receives as its argument.
    input_iterator_handle = self._input_iterator.string_handle()

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    # handle is a scalar `tf.Tensor` of type `tf.string`
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      # Rebuild an iterator from the handle, then return its next element with
      # sparse tensors serialized and the structure flattened.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, self._input_iterator.output_types,
          self._input_iterator.output_shapes,
          self._input_iterator.output_classes)
      ret = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    # Keep the concrete function; it is what gets passed to the buffering
    # resource as `f`.
    self._prefetch_fn = _prefetch_fn._get_concrete_function_internal()  # pylint: disable=protected-access

    # The device of the input iterator's resource is passed as the target
    # device for the prefetch function.
    iterator_device = ged_ops.experimental_iterator_get_device(
        self._input_iterator._iterator_resource)

    with ops.device(device):
      self._buffering_resource = function_buffering_resource(
          f=self._prefetch_fn,
          target_device=iterator_device,
          string_arg=input_iterator_handle,
          buffer_size=buffer_size,
          shared_name=shared_name,
          output_types=nest.flatten(
              sparse.as_dense_types(self._input_dataset.output_types,
                                    self._input_dataset.output_classes)))

    if not self._one_shot:
      # The initializer is created under a control dependency on a reset of
      # the buffering resource.
      reset_op = function_buffering_resource_reset(self._buffering_resource)
      with ops.control_dependencies([reset_op]):
        self._initializer = self._input_iterator.make_initializer(
            self._input_dataset)
Esempio n. 20
0
 def _as_variant_tensor(self):
   """Returns the result of the `map_and_batch_dataset` op."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   captured = self._map_func.captured_inputs
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.map_and_batch_dataset(
       input_variant,
       captured,
       f=self._map_func,
       batch_size=self._batch_size,
       num_parallel_batches=self._num_parallel_batches,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 21
0
 def _as_variant_tensor(self):
   """Returns the result of the `shuffle_and_repeat_dataset` op."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.shuffle_and_repeat_dataset(
       input_variant,
       buffer_size=self._buffer_size,
       count=self._count,
       seed=self._seed,
       seed2=self._seed2,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 22
0
 def _as_variant_tensor(self):
   """Returns the result of the `group_by_window_dataset` op."""
   # pylint: disable=protected-access
   input_variant = self._input_dataset._as_variant_tensor()
   # pylint: enable=protected-access
   key_captured = self._key_func.captured_inputs
   reduce_captured = self._reduce_func.captured_inputs
   window_size_captured = self._window_size_func.captured_inputs
   flat_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.group_by_window_dataset(
       input_variant,
       key_captured,
       reduce_captured,
       window_size_captured,
       key_func=self._key_func,
       reduce_func=self._reduce_func,
       window_size_func=self._window_size_func,
       output_types=flat_types,
       output_shapes=flat_shapes)
Esempio n. 23
0
def get_single_element(dataset):
  """Returns the single element in `dataset` as a nested structure of tensors.

  This function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`.
  This can be useful when your preprocessing transformations are expressed
  as a `Dataset`, and you want to use the transformation at serving time.
  For example:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  # pylint: disable=protected-access
  dataset_variant = dataset._as_variant_tensor()
  # pylint: enable=protected-access
  flat_types = nest.flatten(
      sparse.as_dense_types(dataset.output_types, dataset.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(dataset.output_shapes, dataset.output_classes))
  flat_element = gen_dataset_ops.dataset_to_single_element(
      dataset_variant, output_types=flat_types, output_shapes=flat_shapes)

  # Rebuild the nested structure, then restore any serialized sparse tensors.
  element = nest.pack_sequence_as(dataset.output_types, flat_element)
  return sparse.deserialize_sparse_tensors(
      element, dataset.output_types, dataset.output_shapes,
      dataset.output_classes)
Esempio n. 24
0
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """
    if not context.executing_eagerly():
      message = (
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction")
      raise RuntimeError(message.format(type(self)))

    # The iterator resource and its ops are created under the CPU device.
    with ops.device("/device:CPU:0"):
      # pylint: disable=protected-access
      dataset_variant = dataset._as_variant_tensor()
      # pylint: enable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      dense_types = sparse.as_dense_types(self._output_types,
                                          self._output_classes)
      self._flat_output_types = nest.flatten(dense_types)
      dense_shapes = sparse.as_dense_shapes(self._output_shapes,
                                            self._output_classes)
      self._flat_output_shapes = nest.flatten(dense_shapes)

      container_name = _generate_shared_name("eageriterator")
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=container_name,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(dataset_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")

    # Record the device name of the current context at construction time.
    self._device = context.context().device_name
Esempio n. 25
0
 def get_value(self, name=None):
   """Returns the value held by this optional as a nested structure.

   Args:
     name: (Optional.) Name for the enclosing name scope; defaults to
       "OptionalGetValue".

   Returns:
     The output of `optional_get_value`, packed into the structure of
     `self._output_types` with serialized sparse tensors restored.
   """
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     flat_types = nest.flatten(
         sparse.as_dense_types(self._output_types, self._output_classes))
     flat_shapes = nest.flatten(
         sparse.as_dense_shapes(self._output_shapes, self._output_classes))
     flat_value = gen_dataset_ops.optional_get_value(
         self._variant_tensor,
         name=scope,
         output_types=flat_types,
         output_shapes=flat_shapes)
     value = nest.pack_sequence_as(self._output_types, flat_value)
     return sparse.deserialize_sparse_tensors(
         value, self._output_types, self._output_shapes, self._output_classes)
Esempio n. 26
0
  def _make_finalize_func(self, finalize_func):
    """Builds the Defun wrapper around `finalize_func` and stores it on `self`.

    The wrapper is saved as `self._finalize_func` and added to the default
    graph. When its body runs it also records the output classes, shapes and
    types of `finalize_func`'s result on `self`.

    Args:
      finalize_func: Callable invoked with the deserialized state.
    """
    flat_state_types = nest.flatten(
        sparse.as_dense_types(self._state_types, self._state_classes))

    @function.Defun(*flat_state_types)
    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      flat_state_shapes = nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))
      for arg, shape in zip(args, flat_state_shapes):
        arg.set_shape(shape)

      state = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(self._state_types, args),
          self._state_types, self._state_shapes, self._state_classes)

      result = finalize_func(state)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      converted = []
      for component in nest.flatten(result):
        if sparse_tensor.is_sparse(component):
          converted.append(sparse_tensor.SparseTensor.from_value(component))
        else:
          converted.append(ops.convert_to_tensor(component))
      result = nest.pack_sequence_as(result, converted)

      # Record the output structure on the enclosing object.
      self._output_classes = sparse.get_classes(result)
      self._output_shapes = nest.pack_sequence_as(
          result, [t.get_shape() for t in nest.flatten(result)])
      self._output_types = nest.pack_sequence_as(
          result, [t.dtype for t in nest.flatten(result)])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors.
      flat_serialized = nest.flatten(sparse.serialize_sparse_tensors(result))
      serialized = nest.pack_sequence_as(result, list(flat_serialized))
      return nest.flatten(serialized)

    self._finalize_func = tf_finalize_func
    default_graph = ops.get_default_graph()
    self._finalize_func.add_to_graph(default_graph)
Esempio n. 27
0
 def get_next(self, name=None):
   """Returns a nested structure of `tf.Tensor`s containing the next element.

   Args:
     name: (Optional.) A name for the created operation.

   Returns:
     A nested structure of `tf.Tensor` objects.
   """
   flat_types = nest.flatten(
       sparse.as_dense_types(self._output_types, self._output_classes))
   flat_shapes = nest.flatten(
       sparse.as_dense_shapes(self._output_shapes, self._output_classes))
   flat_element = gen_dataset_ops.iterator_get_next(
       self._iterator_resource,
       output_types=flat_types,
       output_shapes=flat_shapes,
       name=name)

   # Rebuild the nested structure, then restore any serialized sparse tensors.
   element = nest.pack_sequence_as(self._output_types, flat_element)
   return sparse.deserialize_sparse_tensors(
       element, self._output_types, self._output_shapes, self._output_classes)
Esempio n. 28
0
  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`.

    Fetches the next element from `self._buffering_resource`.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    flat_types = nest.flatten(
        sparse.as_dense_types(self.output_types, self.output_classes))
    flat_element = gen_dataset_ops.function_buffering_resource_get_next(
        self._buffering_resource, output_types=flat_types, name=name)

    packed = nest.pack_sequence_as(self.output_types, flat_element)
    element = sparse.deserialize_sparse_tensors(
        packed, self.output_types, self.output_shapes, self.output_classes)

    # Only plain tensors get a static shape assigned here.
    flat_shapes = nest.flatten(self.output_shapes)
    for tensor, shape in zip(nest.flatten(element), flat_shapes):
      if isinstance(tensor, ops.Tensor):
        tensor.set_shape(shape)

    return element
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, target_device, output_shapes, output_types,
               output_classes):
    """Builds the init/next/finalize concrete functions for one shard.

    Each of the three functions is a wrapper that issues a
    `functional_ops.remote_call` targeting `source_device`; the wrappers'
    concrete functions and their captured inputs are stored on `self`.

    Args:
      shard_num: Passed as `shard_num` to
        `multi_device_iterator_get_next_from_shard`.
      multi_device_iterator_resource: Resource converted to a string handle
        that the init function returns.
      incarnation_id: Passed as `incarnation_id` to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Target of every `remote_call` created here.
      target_device: Stored as `self._target_device`; not otherwise used in
        this method.
      output_shapes: Nested output shapes of the elements.
      output_types: Nested output types of the elements.
      output_classes: Nested output classes of the elements.
    """
    self._target_device = target_device
    self._output_types = output_types
    self._output_shapes = output_shapes
    self._output_classes = output_classes
    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # Init: returns the string handle of the multi-device iterator.
    @function.defun()
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()
    # Wrapper that invokes the init function via a remote call on
    # `source_device`.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # Next: turns the string handle back into an iterator and fetches the next
    # element for this shard.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

    next_func_concrete = _next_func.get_concrete_function()
    # Wrapper that invokes the next function via a remote call on
    # `source_device`, forwarding the handle plus the captured inputs.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._flat_output_types,
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    # Finalize: ignores the handle and returns a constant int64 zero.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()
    # Wrapper that invokes the finalize function via a remote call on
    # `source_device`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] +
          finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs
Esempio n. 30
0
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Builds an uninitialized `Iterator` described only by its structure.

    The result is not tied to any dataset and therefore has no `initializer`.
    To use it, create an initialization op per dataset with
    `Iterator.make_initializer(dataset)` and run the one you need. This lets a
    single iterator be reused across several datasets of the same structure.

    Example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element produced by this iterator.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects, one per component of an element produced by this iterator.
        When omitted, every component gets an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server). `None` is treated as `""`.
      output_classes: (Optional.) A nested structure of Python `type` objects,
        one per component of an element produced by this iterator. When
        omitted, every component is assumed to be a `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    # Normalize the three structure descriptions.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    else:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)

    # The underlying op takes flat lists of dense types and shapes.
    flat_dense_types = nest.flatten(
        sparse.as_dense_types(output_types, output_classes))
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(output_shapes, output_classes))
    resource = gen_dataset_ops.iterator(
        container="",
        shared_name="" if shared_name is None else shared_name,
        output_types=flat_dense_types,
        output_shapes=flat_dense_shapes)
    return Iterator(
        resource, None, output_types, output_shapes, output_classes)
Esempio n. 31
0
    def __init__(self, input_dataset, initial_state, scan_func):
        """See `scan()` for details.

        Wraps `scan_func` in a `Defun` and re-traces it until the state shapes
        reach a fixed point, then adds the resulting function to the default
        graph.

        Args:
          input_dataset: The dataset whose elements are passed to `scan_func`;
            its `output_types`, `output_shapes` and `output_classes` describe
            the input-element arguments.
          initial_state: A nested structure of values; sparse values are
            converted to `SparseTensor`s and everything else to tensors.
          scan_func: A callable invoked as `scan_func(old_state, input_element)`
            that must return a pair `(new_state, output_element)`.

        Raises:
          TypeError: Raised while tracing `scan_func` if it does not return a
            pair, or if the classes or dtypes of the new state do not match
            those of the initial state.
        """
        super(_ScanDataset, self).__init__()
        self._input_dataset = input_dataset

        with ops.name_scope("initial_state"):
            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            self._initial_state = nest.pack_sequence_as(
                initial_state, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
                        t, name="component_%d" % i)
                    for i, t in enumerate(nest.flatten(initial_state))
                ])

        # Compute initial values for the state classes, shapes and types based on
        # the initial state. The shapes may be refined by running `tf_scan_func` one
        # or more times below.
        self._state_classes = sparse.get_classes(self._initial_state)
        self._state_shapes = nest.pack_sequence_as(
            self._initial_state,
            [t.get_shape() for t in nest.flatten(self._initial_state)])
        self._state_types = nest.pack_sequence_as(
            self._initial_state,
            [t.dtype for t in nest.flatten(self._initial_state)])

        # Will be populated by calling `tf_scan_func`.
        self._output_classes = None
        self._output_shapes = None
        self._output_types = None

        # Iteratively rerun the scan function until reaching a fixed point on
        # `self._state_shapes`.
        need_to_rerun = True
        while need_to_rerun:

            # Create a list in which `tf_scan_func` will store the new shapes.
            flat_new_state_shapes = []

            @function.Defun(*(nest.flatten(
                sparse.as_dense_types(
                    self._state_types, self._state_classes)) + nest.flatten(
                        sparse.as_dense_types(input_dataset.output_types,
                                              input_dataset.output_classes))))
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference."""
                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                # The flat argument list is state components followed by
                # input-element components; `pivot` is where the split falls.
                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                # NOTE(review): `collections.Sequence` is an alias that was
                # removed in Python 3.10; switch to `collections.abc.Sequence`
                # once this file no longer has to run on Python 2.
                if not isinstance(ret, collections.Sequence) or len(ret) != 2:
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                # Serialize any sparse tensors.
                new_state = nest.pack_sequence_as(
                    new_state,
                    nest.flatten(sparse.serialize_sparse_tensors(new_state)))
                output_value = nest.pack_sequence_as(
                    output_value,
                    nest.flatten(sparse.serialize_sparse_tensors(output_value)))
                return nest.flatten(new_state) + nest.flatten(output_value)

            # Use the private method that will execute `tf_scan_func` but delay
            # adding it to the graph in case we need to rerun the function.
            tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

            # Relax each state shape to the most specific shape compatible with
            # both the current shape and the one `scan_func` just produced.
            flat_state_shapes = nest.flatten(self._state_shapes)
            weakened_state_shapes = [
                original.most_specific_compatible_shape(new) for original, new
                in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            # A rerun is needed if any state shape of known rank was relaxed.
            need_to_rerun = False
            for original_shape, weakened_shape in zip(flat_state_shapes,
                                                      weakened_state_shapes):
                if original_shape.ndims is not None and (
                        weakened_shape.ndims is None or
                        original_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
                # `tf_scan_func`.
                self._state_shapes = nest.pack_sequence_as(
                    self._state_shapes, weakened_state_shapes)

        self._scan_func = tf_scan_func
        self._scan_func.add_to_graph(ops.get_default_graph())
Esempio n. 32
0
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Builds an uninitialized `Iterator` backed by a string handle.

    Use this to define a "feedable" iterator: `string_handle` is typically a
    `tf.placeholder`, and on each `tf.Session.run` call you feed it the value
    of `tf.data.Iterator.string_handle()` for whichever concrete iterator
    should supply the next element.

    For example, with one iterator tracking the position in a training dataset
    and another tracking the position in a test dataset, the source can be
    chosen per step:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element produced by this iterator.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects, one per component of an element produced by this iterator.
        When omitted, every component gets an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects,
        one per component of an element produced by this iterator. When
        omitted, every component is assumed to be a `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    # Normalize the three structure descriptions.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    else:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)

    handle_tensor = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    # The underlying op takes flat lists of dense types and shapes.
    flat_dense_types = nest.flatten(
        sparse.as_dense_types(output_types, output_classes))
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(output_shapes, output_classes))
    resource = gen_dataset_ops.iterator_from_string_handle(
        handle_tensor,
        output_types=flat_dense_types,
        output_shapes=flat_dense_shapes)
    return Iterator(
        resource, None, output_types, output_shapes, output_classes)
 def testAsDenseTypes(self):
   """Checks `sparse.as_dense_types` over a table of nested structures.

   In every expected value, a component whose class is `ops.Tensor` keeps its
   dtype and a component whose class is `SparseTensor` becomes
   `dtypes.string`; empty tuples are preserved in place.
   """
   int32 = dtypes.int32
   string = dtypes.string
   dense = ops.Tensor
   sparse_cls = sparse_tensor.SparseTensor

   # Each row is (types, classes, expected).
   cases = (
       ((), (), ()),
       (int32, dense, int32),
       (int32, sparse_cls, string),
       # NOTE(review): the next two rows are parenthesized scalars, not
       # 1-tuples, so they repeat the two rows above. Confirm whether
       # single-element tuples were intended.
       ((int32), (dense), (int32)),
       ((int32), (sparse_cls), (string)),
       ((int32, ()), (dense, ()), (int32, ())),
       (((), int32), ((), dense), ((), int32)),
       ((int32, ()), (sparse_cls, ()), (string, ())),
       (((), int32), ((), sparse_cls), ((), string)),
       ((int32, (), int32), (dense, (), dense), (int32, (), int32)),
       ((int32, (), int32),
        (sparse_cls, (), sparse_cls),
        (string, (), string)),
       (((), int32, ()), ((), dense, ()), ((), int32, ())),
       (((), int32, ()), ((), sparse_cls, ()), ((), string, ())),
   )
   for types, classes, expected in cases:
     self.assertEqual(sparse.as_dense_types(types, classes), expected)
Esempio n. 34
0
    def _make_reduce_func(self, reduce_func, input_dataset):
        """Make wrapping Defun for reduce_func.

        Traces `reduce_func` inside a `Defun`, repeating the trace with relaxed
        `self._state_shapes` until the shapes stop changing. The final function
        is stored on `self._reduce_func` and added to the default graph.

        Reads `self._state_types`, `self._state_classes` and
        `self._state_shapes`, and may update `self._state_shapes`.

        Args:
          reduce_func: Callable invoked as `reduce_func(state, input_element)`;
            its return value is treated as the new state.
          input_dataset: Dataset whose `output_types`, `output_shapes` and
            `output_classes` describe the input-element arguments.

        Raises:
          TypeError: Raised while tracing if a dtype of the returned state
            differs from the corresponding entry of `self._state_types`.
        """

        # Iteratively rerun the reduce function until reaching a fixed point on
        # `self._state_shapes`.
        need_to_rerun = True
        while need_to_rerun:

            # Create a list in which `tf_reduce_func` will store the new shapes.
            flat_new_state_shapes = []

            # The Defun signature is the flat dense state types followed by the
            # flat dense input-element types.
            @function.Defun(*(nest.flatten(
                sparse.as_dense_types(
                    self._state_types, self._state_classes)) + nest.flatten(
                        sparse.as_dense_types(input_dataset.output_types,
                                              input_dataset.output_classes))))
            def tf_reduce_func(*args):
                """A wrapper for Defun that facilitates shape inference."""
                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                # Split the flat arguments into state components and
                # input-element components, then restore nesting and any
                # sparse tensors.
                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = reduce_func(nested_state_args, nested_input_args)

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])

                # Extract shape information from the returned values.
                flat_new_state = nest.flatten(ret)
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in flat_new_state])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(flat_new_state,
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in flat_new_state])))

                # Serialize any sparse tensors.
                ret = nest.pack_sequence_as(ret, [
                    t
                    for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
                ])
                return nest.flatten(ret)

            # Use the private method that will execute `tf_reduce_func` but delay
            # adding it to the graph in case we need to rerun the function.
            tf_reduce_func._create_definition_if_needed()  # pylint: disable=protected-access

            # Relax each state shape to the most specific shape compatible with
            # both the current shape and the one `reduce_func` just produced.
            flat_state_shapes = nest.flatten(self._state_shapes)
            weakened_state_shapes = [
                old.most_specific_compatible_shape(new)
                for old, new in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            # A rerun is needed if any state shape of known rank was relaxed.
            need_to_rerun = False
            for old_shape, weakened_shape in zip(flat_state_shapes,
                                                 weakened_state_shapes):
                if old_shape.ndims is not None and (
                        weakened_shape.ndims is None
                        or old_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                self._state_shapes = nest.pack_sequence_as(
                    self._state_shapes, weakened_state_shapes)

        self._reduce_func = tf_reduce_func
        self._reduce_func.add_to_graph(ops.get_default_graph())
Esempio n. 35
0
  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details.

    Wraps `scan_func` in a `Defun` and re-traces it until the state shapes
    reach a fixed point; the final function is stored on `self._scan_func`.

    Args:
      input_dataset: The dataset whose elements are passed to `scan_func`.
      initial_state: A nested structure of values; each is converted to a
        tensor.
      scan_func: A callable invoked as `scan_func(old_state, input_element)`
        that must return a pair `(new_state, output_element)`.

    Raises:
      TypeError: Raised while tracing `scan_func` if it does not return a
        pair, or if the dtypes of the new state do not match those of the
        initial state.
    """
    super(_ScanDataset, self).__init__()
    self._input_dataset = input_dataset

    with ops.name_scope("initial_state"):
      self._initial_state = nest.pack_sequence_as(initial_state, [
          ops.convert_to_tensor(t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])

    # Compute initial values for the state shapes and types based on
    # the initial state. These will be refined by running
    # `tf_scan_func` one or more times below.
    # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
    self._state_shapes = nest.pack_sequence_as(
        self._initial_state,
        [t.shape for t in nest.flatten(self._initial_state)])
    self._state_types = nest.pack_sequence_as(
        self._initial_state,
        [t.dtype for t in nest.flatten(self._initial_state)])

    # Will be populated by calling `tf_scan_func`.
    self._output_classes = None
    self._output_shapes = None
    self._output_types = None

    # Iteratively rerun the scan function until reaching a fixed point on
    # `self._state_shapes`.
    need_to_rerun = True
    while need_to_rerun:

      flat_state_shapes = nest.flatten(self._state_shapes)
      flat_state_types = nest.flatten(self._state_types)

      # Create a list in which `tf_scan_func` will store the new state shapes.
      flat_new_state_shapes = []

      # The Defun signature is the flat state types followed by the flat dense
      # input-element types.
      @function.Defun(*(flat_state_types + nest.flatten(
          sparse.as_dense_types(input_dataset.output_types,
                                input_dataset.output_classes))))
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        # Pass in shape information from the state and input_dataset.
        # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
        dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                              input_dataset.output_classes)
        for arg, shape in zip(args,
                              flat_state_shapes + nest.flatten(dense_shapes)):
          arg.set_shape(shape)

        # The flat argument list is state components followed by input-element
        # components; `pivot` is where the split falls.
        pivot = len(flat_state_shapes)
        old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
        input_value = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])

        ret = scan_func(old_state, input_value)
        if not isinstance(ret, collections.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")
        new_state, output_value = ret

        flat_new_state = [
            ops.convert_to_tensor(t) for t in nest.flatten(new_state)
        ]
        flat_output_value = [
            ops.convert_to_tensor(t) for t in nest.flatten(output_value)
        ]

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend([t.shape for t in flat_new_state])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.shape for t in flat_output_value])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, flat_state_types):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, nest.pack_sequence_as(
                    self._state_types, [t.dtype for t in flat_new_state])))
        # Every output component was converted to a tensor above, so its class
        # is recorded as `ops.Tensor`.
        self._output_classes = nest.pack_sequence_as(
            output_value, [ops.Tensor for _ in flat_output_value])
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in flat_output_value])

        return flat_new_state + flat_output_value

      # Use the private method that will execute `tf_scan_func` but delay
      # adding it to the graph in case we need to rerun the function.
      tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

      # Relax each state shape to the most specific shape compatible with both
      # the current shape and the one `scan_func` just produced.
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # A rerun is needed if any state shape of known rank was relaxed.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
        # `tf_scan_func`.
        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                   weakened_state_shapes)

    self._scan_func = tf_scan_func
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Creates the multi-device iterator resource on `source_device`, initializes
    it from `dataset`, and then builds one initializable iterator per entry of
    `devices`. The per-device initializers are grouped into a single op.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 0, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

    Raises:
      RuntimeError: If run in Eager mode.
    """
    if context.executing_eagerly():
      # TODO(rohanj): Fix this. Tracking bug: b/116467184
      raise RuntimeError("MultiDeviceIterator is not currently supported in "
                         "Eager mode.")
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    # Flat, dense views of the element structure, as required by the ops.
    element_shapes = self._dataset.output_shapes
    element_types = self._dataset.output_types
    element_classes = self._dataset.output_classes
    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(element_shapes, element_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(element_types, element_classes))

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      resource = gen_dataset_ops.multi_device_iterator(
          devices=self._devices,
          shared_name="",
          container="",
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      self._multi_device_iterator_resource = resource

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._as_variant_tensor(),  # pylint: disable=protected-access
          resource,
          max_buffer_size=max_buffer_size)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    wants_prefetch = prefetch_buffer_size > 0
    for shard_index, target_device in enumerate(self._devices):
      per_device_ds = _PerDeviceGenerator(
          shard_index, self._multi_device_iterator_resource,
          self._incarnation_id, self._source_device_tensor, target_device,
          element_shapes, element_types, element_classes)
      if wants_prefetch:
        per_device_ds = per_device_ds.prefetch(prefetch_buffer_size)
      with ops.device(target_device):
        self._device_iterators.append(
            per_device_ds.make_initializable_iterator())

    # A single op that runs every per-device initializer.
    self._initializer = control_flow_ops.group(
        *[it.initializer for it in self._device_iterators])
Esempio n. 37
0
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

        For example:
        ```python
        dataset = tf.data.Dataset.range(4)
        for x in Iterator(dataset):
          print(x)
        ```

        Tensors produced will be placed on the device on which this iterator
        object was created.

        The iterator resource itself is always created under
        "/device:CPU:0". If the current eager context reports a non-empty
        device type other than "CPU", a function buffering resource is
        additionally created on that device; it is fed by a function that
        calls `get_next()` on the CPU-side iterator through its string handle.

        Args:
          dataset: A `tf.data.Dataset` object.

        Raises:
          RuntimeError: When invoked without eager execution enabled.
        """
        if not context.in_eager_mode():
            # `tf.data.Dataset` has no `make_iterator` method; point callers at
            # the two constructors that exist for graph mode.
            raise RuntimeError(
                "{} objects can only be used when eager execution is enabled, use "
                "tf.data.Dataset.make_initializable_iterator or "
                "tf.data.Dataset.make_one_shot_iterator for graph construction"
                .format(type(self)))

        with ops.device("/device:CPU:0"):
            ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
            self._output_classes = dataset.output_classes
            self._output_types = dataset.output_types
            self._output_shapes = dataset.output_shapes
            # The iterator ops take flat lists of dense types/shapes.
            self._flat_output_types = nest.flatten(
                sparse.as_dense_types(self._output_types,
                                      self._output_classes))
            self._flat_output_shapes = nest.flatten(
                sparse.as_dense_shapes(self._output_shapes,
                                       self._output_classes))
            self._resource = gen_dataset_ops.iterator(
                shared_name="",
                container=_generate_shared_name("eageriterator"),
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            gen_dataset_ops.make_iterator(ds_variant, self._resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._resource, handle_device="/device:CPU:0")

        self._device = context.context().device_name
        self._buffer_resource_handle = None

        # An empty device type is treated like CPU: no buffering resource is
        # created in that case.
        device_type = context.context().device_spec.device_type
        is_remote_device = bool(device_type) and device_type != "CPU"
        if is_remote_device:
            with ops.device("/device:CPU:0"):
                iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
                    self._resource)

                @function.Defun(dtypes.string)
                def remote_fn(h):
                    remote_iterator = iterator_ops.Iterator.from_string_handle(
                        h, self._output_types, self._output_shapes)
                    return remote_iterator.get_next()

                remote_fn.add_to_graph(None)
                target = constant_op.constant("/device:CPU:0")
            with ops.device(self._device):
                self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
                    string_arg=iter_string_handle,
                    f=remote_fn,
                    target_device=target,
                    buffer_size=10,
                    thread_pool_size=1,
                    container="",
                    shared_name=_generate_shared_name(
                        "function_buffer_resource"))
                # Delete the buffering resource when this object is deleted.
                self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
                    handle=self._buffer_resource_handle,
                    handle_device=self._device)
Esempio n. 38
0
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Builds an uninitialized `Iterator` described only by its structure.

    The resulting iterator is not tied to any dataset and therefore has no
    `initializer`. It can be reused with many datasets of a matching
    structure: run the op returned by `Iterator.make_initializer(dataset)` to
    (re)bind it.

    Example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects, one per component of an element. When omitted, every
        component gets an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects,
        one per component of an element. When omitted, every component is
        assumed to be a `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    # Normalize the three structure descriptions.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    else:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    if shared_name is None:
      shared_name = ""

    # Every variant of the iterator op takes the same keyword arguments.
    resource_kwargs = dict(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))

    if not compat.forward_compatible(2018, 8, 3):
      iterator_resource = gen_dataset_ops.iterator(**resource_kwargs)
    elif _device_stack_is_empty():
      # When `_device_stack_is_empty()` reports true, the V2 resource is
      # created under "/cpu:0".
      with ops.device("/cpu:0"):
        iterator_resource = gen_dataset_ops.iterator_v2(**resource_kwargs)
    else:
      iterator_resource = gen_dataset_ops.iterator_v2(**resource_kwargs)

    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Esempio n. 39
0
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Builds an uninitialized `Iterator` that is backed by a string handle.

    Use this to define a "feedable" iterator: `string_handle` is typically a
    `tf.placeholder`, and the concrete iterator is selected per
    `tf.Session.run` call by feeding the value of
    `tf.data.Iterator.string_handle`.

    For example, with one iterator tracking the position in a training dataset
    and another in a test dataset, the one to use can be picked at each step:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects, one per component of an element. When omitted, every
        component gets an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects,
        one per component of an element. When omitted, every component is
        assumed to be a `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    # Normalize the three structure descriptions.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    else:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)

    # The handle is converted outside of any device scope opened below.
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

    flat_types = nest.flatten(
        sparse.as_dense_types(output_types, output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(output_shapes, output_classes))

    if not compat.forward_compatible(2018, 8, 3):
      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
          string_handle, output_types=flat_types, output_shapes=flat_shapes)
    elif _device_stack_is_empty():
      # When `_device_stack_is_empty()` reports true, the V2 resource is
      # created under "/cpu:0".
      with ops.device("/cpu:0"):
        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle, output_types=flat_types, output_shapes=flat_shapes)
    else:
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle, output_types=flat_types, output_shapes=flat_shapes)

    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Esempio n. 40
0
  def get_next(self, name=None):
    """Returns the next element as a nested structure of `tf.Tensor`s.

    In graph mode this method is normally called *once*, and its result is fed
    into further computation; a loop then repeatedly runs that computation with
    `tf.Session.run` until `Iterator.get_next()` raises
    `tf.errors.OutOfRangeError`. Skeleton of such a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: Calling `Iterator.get_next()` several times is legitimate, e.g. to
    hand different elements to multiple devices within one step. A common
    pitfall, however, is calling it in every iteration of a training loop:
    each call adds ops to the graph, and executing each op allocates resources
    (including threads), which leads to slowdown and eventual resource
    exhaustion. To guard against this, a warning is emitted on every call made
    after the call count exceeds `GET_NEXT_CALL_WARNING_THRESHOLD`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # The op works on flat lists of dense types/shapes.
    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_components = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)

    # Restore the nested structure, then deserialize the packed element.
    packed_element = nest.pack_sequence_as(self._output_types, flat_components)
    return sparse.deserialize_sparse_tensors(
        packed_element, self._output_types, self._output_shapes,
        self._output_classes)
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, target_device, output_shapes,
                 output_types, output_classes):
        """Builds the init/next/finalize functions for one device shard.

        Three `function.Defun` pairs are created. Each pair consists of a local
        function and a "remote" wrapper that invokes it through
        `functional_ops.remote_call` with `target=source_device`. The wrappers
        and their captured inputs are stored on `self` as `_init_func`,
        `_next_func`, `_finalize_func` and the matching `_*_captured_args`.

        Args:
          shard_num: Shard number passed to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource that is converted to a
            string handle and returned by the init function.
          incarnation_id: Incarnation id passed to
            `multi_device_iterator_get_next_from_shard`.
          source_device: Passed as `target` to every
            `functional_ops.remote_call`.
          target_device: Stored as `self._target_device`; not otherwise used
            in this method.
          output_shapes: Nested structure of element shapes.
          output_types: Nested structure of element types.
          output_classes: Nested structure of element classes, used to
            convert shapes and types to their dense equivalents.
        """
        self._target_device = target_device
        self._output_types = output_types
        self._output_shapes = output_shapes
        self._output_classes = output_classes
        # Flat, dense lists of shapes/types, as used by the ops below.
        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._output_shapes, self._output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes))

        # String handle of the multi-device iterator; captured by `_init_func`.
        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # Init: returns the string handle of the multi-device iterator.
        @function.Defun()
        def _init_func():
            return multi_device_iterator_string_handle

        @function.Defun()
        def _remote_init_func():
            return functional_ops.remote_call(target=source_device,
                                              args=_init_func.captured_inputs,
                                              Tout=[dtypes.string],
                                              f=_init_func)

        self._init_func = _remote_init_func
        self._init_captured_args = _remote_init_func.captured_inputs

        # Next: looks the multi-device iterator up from the string handle and
        # fetches the next element for this shard.
        @function.Defun(dtypes.string)
        def _next_func(string_handle):
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

        @function.Defun(dtypes.string)
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(target=source_device,
                                              args=[string_handle] +
                                              _next_func.captured_inputs,
                                              Tout=self._flat_output_types,
                                              f=_next_func)

        self._next_func = _remote_next_func
        self._next_captured_args = _remote_next_func.captured_inputs

        # Finalize: ignores the handle and only returns an int64 constant 0.
        @function.Defun(dtypes.string)
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        @function.Defun(dtypes.string)
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(target=source_device,
                                              args=[string_handle] +
                                              _finalize_func.captured_inputs,
                                              Tout=[dtypes.int64],
                                              f=_finalize_func)

        self._finalize_func = _remote_finalize_func
        self._finalize_captured_args = _remote_finalize_func.captured_inputs
Esempio n. 42
0
    def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
        """Constructs a _CopyToDeviceDataset.

        Builds three pairs of `function.defun` functions (init, next and
        finalize). Each pair has a local function and a "remote" wrapper that
        calls the local concrete function through
        `functional_ops.remote_call` with `target=self._source_device`. The
        remote concrete functions and their captured inputs are stored on
        `self` and added to the default graph.

        Args:
          input_dataset: `Dataset` to be copied
          target_device: The name of the device to which elements would be
            copied.
          source_device: Device where input_dataset would be placed.
        """
        super(_CopyToDeviceDataset, self).__init__(input_dataset)
        self._input_dataset = input_dataset
        self._target_device = target_device
        # Parse the target device string to record whether it names a GPU.
        spec = framework_device.DeviceSpec().from_string(self._target_device)
        self._is_gpu_target = (spec.device_type == "GPU")
        # Keep the source device both as a Python string (for `ops.device`)
        # and as a tensor (for `remote_call`'s `target`).
        self._source_device_string = source_device
        self._source_device = ops.convert_to_tensor(source_device)

        # Flat, dense lists of shapes/types, as used by the ops below.
        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._input_dataset.output_shapes,
                                   self._input_dataset.output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._input_dataset.output_types,
                                  self._input_dataset.output_classes))

        @function.defun()
        def _init_func():
            """Creates an iterator for the input dataset.

            Returns:
              A `string` tensor that encapsulates the iterator created.
            """
            # pylint: disable=protected-access
            ds_variant = self._input_dataset._as_variant_tensor()
            resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            # The handle is only produced after the iterator has been made.
            with ops.control_dependencies(
                [gen_dataset_ops.make_iterator(ds_variant, resource)]):
                return gen_dataset_ops.iterator_to_string_handle(resource)

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun()
        def _remote_init_func():
            return functional_ops.remote_call(
                target=self._source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _next_func(string_handle):
            """Calls get_next for created iterator.

            Args:
              string_handle: An iterator string handle created by _init_func
            Returns:
              The elements generated from `input_dataset`
            """
            # Only the handle lookup is placed under the source device scope;
            # `get_next()` is issued outside of it.
            # `self.output_types` / `output_shapes` / `output_classes` are
            # defined outside this method.
            with ops.device(self._source_device_string):
                iterator = iterator_ops.Iterator.from_string_handle(
                    string_handle, self.output_types, self.output_shapes,
                    self.output_classes)
            ret = iterator.get_next()
            # Serialize sparse components and flatten to a list of tensors.
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=self._flat_output_types,
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _finalize_func(string_handle):
            """Destroys the iterator resource created.

            Args:
              string_handle: An iterator string handle created by _init_func
            Returns:
              Tensor constant 0
            """
            iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
                string_handle,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            # `ignore_lookup_error=True` is passed so a failed resource lookup
            # is not reported as an error by the destroy op.
            with ops.control_dependencies([
                    resource_variable_ops.destroy_resource_op(
                        iterator_resource, ignore_lookup_error=True)
            ]):
                return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
        )
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Register the three remote concrete functions with the default graph.
        g = ops.get_default_graph()
        self._init_func.add_to_graph(g)
        self._next_func.add_to_graph(g)
        self._finalize_func.add_to_graph(g)