Example #1
      def tf_reduce_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = reduce_func(nested_state_args, nested_input_args)

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])

        # Extract shape information from the returned values.
        flat_new_state = nest.flatten(ret)
        flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(self._state_types,
                                       [t.dtype for t in flat_new_state])))

        dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        ret = nest.pack_sequence_as(
            ret, nest.flatten(sparse.serialize_sparse_tensors(ret)))
        return nest.flatten(ret)
Example #2
def get_next_as_optional(iterator):
  """Returns an `Optional` that contains the next value from the iterator.

  If `iterator` has reached the end of the sequence, the returned `Optional`
  will have no value.

  Args:
    iterator: A `tf.data.Iterator` object.

  Returns:
    An `Optional` object representing the next value from the iterator (if it
    has one) or no value.
  """
  # pylint: disable=protected-access
  return optional_ops._OptionalImpl(
      gen_dataset_ops.iterator_get_next_as_optional(
          iterator._iterator_resource,
          output_types=nest.flatten(
              sparse.as_dense_types(iterator.output_types,
                                    iterator.output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(iterator.output_shapes,
                                     iterator.output_classes))),
      structure.Structure._from_legacy_structure(iterator.output_types,
                                                 iterator.output_shapes,
                                                 iterator.output_classes))
Example #3
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access
Example #4
  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.as_dense_types(
                                          self._output_types,
                                          self._output_classes)),
                                  output_shapes=nest.flatten(
                                      sparse.as_dense_shapes(
                                          self._output_shapes,
                                          self._output_classes)),
                                  name=name)), self._output_types,
        self._output_shapes, self._output_classes)
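The intended consumption pattern for `get_next()` (a sketch under TF 1.x graph-mode assumptions): build the op once and fetch its result repeatedly, which is exactly what the call-count warning above guards against.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  # created once, outside the loop

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))  # 0, 1, 2
    except tf.errors.OutOfRangeError:
      break
```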
Example #5
    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(self._state_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(nested_args)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in nest.flatten(ret)])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in nest.flatten(ret)])

      # Serialize any sparse tensors.
      ret = nest.pack_sequence_as(
          ret, nest.flatten(sparse.serialize_sparse_tensors(ret)))
      return nest.flatten(ret)
Example #6
    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(self._state_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(nested_args)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in nest.flatten(ret)])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in nest.flatten(ret)])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors.
      ret = nest.pack_sequence_as(
          ret, nest.flatten(sparse.serialize_sparse_tensors(ret)))
      return nest.flatten(ret)
Example #7
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret)
      if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
      return ret
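For context, a hedged sketch of the transformation this key function serves, assuming the TF 1.x `tf.contrib.data.group_by_reducer` API and its `Reducer`; as validated above, `key_func` must return a scalar `tf.int64`.

```python
import tensorflow as tf

# Sum the even and odd elements of [0, 10) into separate groups.
reducer = tf.contrib.data.Reducer(
    init_func=lambda key: tf.constant(0, dtype=tf.int64),
    reduce_func=lambda state, element: state + element,
    finalize_func=lambda state: state)

dataset = tf.data.Dataset.range(10).apply(
    tf.contrib.data.group_by_reducer(
        key_func=lambda x: x % 2,  # scalar tf.int64, as required
        reducer=reducer))
```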
Example #8
 def _as_variant_tensor(self):
   return gen_dataset_ops.ignore_errors_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)))
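A usage sketch for the op built above, assuming the TF 1.x `tf.contrib.data.ignore_errors` export: elements whose processing raises an error are silently dropped.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
# 1. / 0. produces Inf, so check_numerics raises for the third element.
dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
# The failing element is dropped; the rest pass through unchanged.
dataset = dataset.apply(tf.contrib.data.ignore_errors())
```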
Example #9
 def _as_variant_tensor(self):
   return gen_dataset_ops.random_dataset(
       seed=self._seed,
       seed2=self._seed2,
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)))
Example #10
 def _as_variant_tensor(self):
   return gen_dataset_ops.set_stats_aggregator_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._stats_aggregator._resource,  # pylint: disable=protected-access
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #11
 def _as_variant_tensor(self):
   return self._op_function(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._tag,
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #12
 def _as_variant_tensor(self):
   # pylint: disable=protected-access
   return gen_dataset_ops.directed_interleave_dataset(
       self._selector_input._as_variant_tensor(),
       [data_input._as_variant_tensor() for data_input in self._data_inputs],
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)))
Example #13
 def _as_variant_tensor(self):
   return gen_dataset_ops.slide_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       window_size=self._window_size,
       stride=self._stride,
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)))
Example #14
 def _as_variant_tensor(self):
   return gen_dataset_ops.dense_to_sparse_batch_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._batch_size,
       row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape),  # pylint: disable=protected-access
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)))
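Sketch of applying this op through its public wrapper (assuming TF 1.x `tf.contrib.data.dense_to_sparse_batch`): variable-length rows are batched into one `SparseTensor` whose dense shape is `[batch_size] + row_shape`.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(4)
dataset = dataset.map(lambda x: tf.fill([x], x))  # rows of length 0..3
dataset = dataset.apply(
    tf.contrib.data.dense_to_sparse_batch(batch_size=4, row_shape=[4]))
```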
Example #15
 def _as_variant_tensor(self):
   input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   return gen_dataset_ops.scan_dataset(
       input_t,
       nest.flatten(self._initial_state),
       self._scan_func.captured_inputs,
       f=self._scan_func,
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #16
  def __init__(self,
               dataset,
               devices,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    self._dataset = dataset
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)

    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._dataset.output_shapes,
                               self._dataset.output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._dataset.output_types,
                              self._dataset.output_classes))

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name="",
              container="",
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._as_variant_tensor(),  # pylint: disable=protected-access
          self._multi_device_iterator_resource)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      ds = _PerDeviceGenerator(
          i, self._multi_device_iterator_resource, self._incarnation_id,
          self._source_device_tensor, device, self._dataset.output_shapes,
          self._dataset.output_types, self._dataset.output_classes)
      ds = ds.prefetch(prefetch_buffer_size)
      with ops.device(device):
        self._device_iterators.append(ds.make_initializable_iterator())

    device_iterator_initializers = [
        iterator.initializer for iterator in self._device_iterators
    ]
    self._initializer = control_flow_ops.group(*device_iterator_initializers)
Example #17
 def _as_variant_tensor(self):
   return gen_dataset_ops.parallel_interleave_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._map_func.captured_inputs,
       self._cycle_length,
       self._block_length,
       self._sloppy,
       f=self._map_func,
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #18
 def _as_variant_tensor(self):
   # pylint: disable=protected-access
   return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset(
       self._input_dataset._as_variant_tensor(),
       batch_size=self._batch_size,
       padded_shapes=[
           ops.convert_to_tensor(s, dtype=dtypes.int64)
           for s in nest.flatten(self._padded_shapes)
       ],
       padding_values=nest.flatten(self._padding_values),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #19
 def _as_variant_tensor(self):
   return gen_dataset_ops.group_by_window_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._key_func.captured_inputs,
       self._reduce_func.captured_inputs,
       self._window_size_func.captured_inputs,
       key_func=self._key_func,
       reduce_func=self._reduce_func,
       window_size_func=self._window_size_func,
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #20
 def _as_variant_tensor(self):
   # pylint: disable=protected-access
   input_resource = self._input_dataset._as_variant_tensor()
   return gen_dataset_ops.map_and_batch_dataset(
       input_resource,
       self._map_func.captured_inputs,
       f=self._map_func,
       batch_size=self._batch_size,
       num_parallel_batches=self._num_parallel_batches,
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
Example #21
 def _as_variant_tensor(self):
   # pylint: disable=protected-access
   input_resource = self._input_dataset._as_variant_tensor()
   return gen_dataset_ops.shuffle_and_repeat_dataset(
       input_resource,
       buffer_size=self._buffer_size,
       count=self._count,
       seed=self._seed,
       seed2=self._seed2,
       output_types=nest.flatten(
           sparse.as_dense_types(self.output_types, self.output_classes)),
       output_shapes=nest.flatten(
           sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
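Usage sketch for the fused op above, assuming the TF 1.x `tf.contrib.data.shuffle_and_repeat` export:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# Behaves like dataset.shuffle(10).repeat(5), but fused so shuffling
# continues seamlessly across epoch boundaries.
dataset = dataset.apply(
    tf.contrib.data.shuffle_and_repeat(buffer_size=10, count=5))
```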
Example #22
def get_single_element(dataset):
  """Returns the single element in `dataset` as a nested structure of tensors.

  This function enables you to use a @{tf.data.Dataset} in a stateless
  "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}.
  This can be useful when your preprocessing transformations are expressed
  as a `Dataset`, and you want to use the transformation at serving time.
  For example:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
  ```

  Args:
    dataset: A @{tf.data.Dataset} object containing a single element.

  Returns:
    A nested structure of @{tf.Tensor} objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  nested_ret = nest.pack_sequence_as(
      dataset.output_types, gen_dataset_ops.dataset_to_single_element(
          dataset._as_variant_tensor(),  # pylint: disable=protected-access
          output_types=nest.flatten(sparse.as_dense_types(
              dataset.output_types, dataset.output_classes)),
          output_shapes=nest.flatten(sparse.as_dense_shapes(
              dataset.output_shapes, dataset.output_classes))))
  return sparse.deserialize_sparse_tensors(
      nested_ret, dataset.output_types, dataset.output_shapes,
      dataset.output_classes)
Example #23
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        # Pass in shape information from the state and input_dataset.
        # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
        dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                              input_dataset.output_classes)
        for arg, shape in zip(args,
                              flat_state_shapes + nest.flatten(dense_shapes)):
          arg.set_shape(shape)

        pivot = len(flat_state_shapes)
        old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
        input_value = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])

        ret = scan_func(old_state, input_value)
        if not isinstance(ret, collections.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")
        new_state, output_value = ret

        flat_new_state = [
            ops.convert_to_tensor(t) for t in nest.flatten(new_state)
        ]
        flat_output_value = [
            ops.convert_to_tensor(t) for t in nest.flatten(output_value)
        ]

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend([t.shape for t in flat_new_state])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.shape for t in flat_output_value])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, flat_state_types):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, nest.pack_sequence_as(
                    self._state_types, [t.dtype for t in flat_new_state])))
        self._output_classes = nest.pack_sequence_as(
            output_value, [ops.Tensor for _ in flat_output_value])
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in flat_output_value])

        return flat_new_state + flat_output_value
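A hedged usage sketch of the transformation this wrapper backs, assuming TF 1.x `tf.contrib.data.scan(initial_state, scan_func)`; `scan_func` must return the `(new_state, output_value)` pair checked above.

```python
import tensorflow as tf

# Running sum over the input: the state is the accumulated total, and
# each output element is the total including the current input.
dataset = tf.data.Dataset.range(5).apply(
    tf.contrib.data.scan(
        initial_state=tf.constant(0, dtype=tf.int64),
        scan_func=lambda state, x: (state + x, state + x)))
# Yields 0, 1, 3, 6, 10.
```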
Example #26
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      self._flat_output_types = nest.flatten(
          sparse.as_dense_types(self._output_types, self._output_classes))
      self._flat_output_shapes = nest.flatten(
          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=_generate_shared_name("eageriterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
Example #27
 def get_value(self, name=None):
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     return sparse.deserialize_sparse_tensors(
         nest.pack_sequence_as(
             self._output_types,
             gen_dataset_ops.optional_get_value(
                 self._variant_tensor,
                 name=scope,
                 output_types=nest.flatten(
                     sparse.as_dense_types(self._output_types,
                                           self._output_classes)),
                 output_shapes=nest.flatten(
                     sparse.as_dense_shapes(self._output_shapes,
                                            self._output_classes)))),
         self._output_types, self._output_shapes, self._output_classes)
Example #29
 def get_next(self, name=None):
   """Returns a nested structure of `tf.Tensor`s containing the next element.
   Args:
     name: (Optional.) A name for the created operation.
   Returns:
     A nested structure of `tf.Tensor` objects.
   """
   return sparse.deserialize_sparse_tensors(
       nest.pack_sequence_as(self._output_types,
                             gen_dataset_ops.iterator_get_next(
                                 self._iterator_resource,
                                 output_types=nest.flatten(
                                     sparse.as_dense_types(
                                         self._output_types,
                                         self._output_classes)),
                                 output_shapes=nest.flatten(
                                     sparse.as_dense_shapes(
                                         self._output_shapes,
                                         self._output_classes)),
                                 name=name)), self._output_types,
       self._output_shapes, self._output_classes)
Example #30
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret
Example #33
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.placeholder`, and you would
    feed it with the value of `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    if compat.forward_compatible(2018, 8, 3):
      if _device_stack_is_empty():
        with ops.device("/cpu:0"):
          iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
              string_handle,
              output_types=nest.flatten(
                  sparse.as_dense_types(output_types, output_classes)),
              output_shapes=nest.flatten(
                  sparse.as_dense_shapes(output_shapes, output_classes)))
      else:
        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle,
            output_types=nest.flatten(
                sparse.as_dense_types(output_types, output_classes)),
            output_shapes=nest.flatten(
                sparse.as_dense_shapes(output_shapes, output_classes)))
    else:
      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
          string_handle,
          output_types=nest.flatten(
              sparse.as_dense_types(output_types, output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(output_shapes, output_classes)))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Example #34
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    if shared_name is None:
      shared_name = ""
    if compat.forward_compatible(2018, 8, 3):
      if _device_stack_is_empty():
        with ops.device("/cpu:0"):
          iterator_resource = gen_dataset_ops.iterator_v2(
              container="",
              shared_name=shared_name,
              output_types=nest.flatten(
                  sparse.as_dense_types(output_types, output_classes)),
              output_shapes=nest.flatten(
                  sparse.as_dense_shapes(output_shapes, output_classes)))
      else:
        iterator_resource = gen_dataset_ops.iterator_v2(
            container="",
            shared_name=shared_name,
            output_types=nest.flatten(
                sparse.as_dense_types(output_types, output_classes)),
            output_shapes=nest.flatten(
                sparse.as_dense_shapes(output_shapes, output_classes)))
    else:
      iterator_resource = gen_dataset_ops.iterator(
          container="",
          shared_name=shared_name,
          output_types=nest.flatten(
              sparse.as_dense_types(output_types, output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(output_shapes, output_classes)))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Example #36
 def testAsDenseShapes(self, types_fn, classes_fn, expected_fn):
   types = types_fn()
   classes = classes_fn()
   expected = expected_fn()
   self.assertShapesEqual(sparse.as_dense_shapes(types, classes), expected)
Example #37
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference."""
                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                if not isinstance(ret, collections.Sequence) or len(ret) != 2:
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

                # Serialize any sparse tensors.
                new_state = nest.pack_sequence_as(new_state, [
                    t for t in nest.flatten(
                        sparse.serialize_sparse_tensors(new_state))
                ])
                output_value = nest.pack_sequence_as(output_value, [
                    t for t in nest.flatten(
                        sparse.serialize_sparse_tensors(output_value))
                ])
                return nest.flatten(new_state) + nest.flatten(output_value)
Example #38
 def testAsDenseShapes(self):
   test_cases = (
       {
           "types": (),
           "classes": (),
           "expected": ()
       },
       {
           "types": tensor_shape.scalar(),
           "classes": ops.Tensor,
           "expected": tensor_shape.scalar()
       },
       {
           "types": tensor_shape.scalar(),
           "classes": sparse_tensor.SparseTensor,
           "expected": tensor_shape.unknown_shape()
       },
        {
            "types": (tensor_shape.scalar(),),
            "classes": (ops.Tensor,),
            "expected": (tensor_shape.scalar(),)
        },
        {
            "types": (tensor_shape.scalar(),),
            "classes": (sparse_tensor.SparseTensor,),
            "expected": (tensor_shape.unknown_shape(),)
        },
       {
           "types": (tensor_shape.scalar(), ()),
           "classes": (ops.Tensor, ()),
           "expected": (tensor_shape.scalar(), ())
       },
       {
           "types": ((), tensor_shape.scalar()),
           "classes": ((), ops.Tensor),
           "expected": ((), tensor_shape.scalar())
       },
       {
           "types": (tensor_shape.scalar(), ()),
           "classes": (sparse_tensor.SparseTensor, ()),
           "expected": (tensor_shape.unknown_shape(), ())
       },
       {
           "types": ((), tensor_shape.scalar()),
           "classes": ((), sparse_tensor.SparseTensor),
           "expected": ((), tensor_shape.unknown_shape())
       },
       {
           "types": (tensor_shape.scalar(), (), tensor_shape.scalar()),
           "classes": (ops.Tensor, (), ops.Tensor),
           "expected": (tensor_shape.scalar(), (), tensor_shape.scalar())
       },
       {
           "types": (tensor_shape.scalar(), (), tensor_shape.scalar()),
           "classes": (sparse_tensor.SparseTensor, (),
                       sparse_tensor.SparseTensor),
           "expected": (tensor_shape.unknown_shape(), (),
                        tensor_shape.unknown_shape())
       },
       {
           "types": ((), tensor_shape.scalar(), ()),
           "classes": ((), ops.Tensor, ()),
           "expected": ((), tensor_shape.scalar(), ())
       },
       {
           "types": ((), tensor_shape.scalar(), ()),
           "classes": ((), sparse_tensor.SparseTensor, ()),
           "expected": ((), tensor_shape.unknown_shape(), ())
       },
   )
   for test_case in test_cases:
     self.assertShapesEqual(
         sparse.as_dense_shapes(test_case["types"], test_case["classes"]),
         test_case["expected"])
Example #39
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a @{tf.Session.run} call.
    In that case, `string_handle` would be a @{tf.placeholder}, and you
    would feed it with the value of @{tf.data.Iterator.string_handle} in
    each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
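
A small aside on the defaulting logic at the end of `from_string_handle`: when `output_shapes` is omitted, every component is given a fully unknown `TensorShape`, which is what makes the shapes "unconstrained". A quick sketch using the same internal `nest` module the code above relies on (the dict structure here is just an illustrative assumption):

```python
import tensorflow as tf
from tensorflow.python.util import nest

output_types = {"image": tf.float32, "label": tf.int64}
output_shapes = nest.map_structure(
    lambda _: tf.TensorShape(None), output_types)
print(output_shapes)  # every component maps to an unknown TensorShape
```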
Example #40
  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    super(_CopyToDeviceDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._input_dataset.output_shapes,
                               self._input_dataset.output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._input_dataset.output_types,
                              self._input_dataset.output_classes))

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      # pylint: disable=protected-access
      ds_variant = self._input_dataset._as_variant_tensor()
      resource = gen_dataset_ops.anonymous_iterator(
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle, self.output_types, self.output_shapes,
            self.output_classes)
      ret = iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._flat_output_types,
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] +
          finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
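
`_CopyToDeviceDataset` is the machinery behind the public copy-to-device transformation. A hedged usage sketch, assuming the TF 1.x graph-mode entry point `tf.data.experimental.copy_to_device`:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# Produce elements on the CPU and copy each one to the GPU.
dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
with tf.device("/gpu:0"):
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))  # 0
```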
Example #41
    def __init__(self,
                 dataset,
                 devices,
                 max_buffer_size=1,
                 prefetch_buffer_size=1,
                 source_device="/cpu:0"):
        """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we set up a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

      In order to prevent deadlocks, if the prefetch_buffer_size is greater
      than the max_buffer_size, we set the max_buffer_size to
      prefetch_buffer_size.

    """
        self._dataset = dataset._apply_options()  # pylint: disable=protected-access
        self._devices = devices
        self._source_device = source_device
        self._source_device_tensor = ops.convert_to_tensor(source_device)
        self._max_buffer_size = max_buffer_size
        self._prefetch_buffer_size = prefetch_buffer_size

        if self._prefetch_buffer_size > self._max_buffer_size:
            self._max_buffer_size = self._prefetch_buffer_size

        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._dataset.output_shapes,
                                   self._dataset.output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._dataset.output_types,
                                  self._dataset.output_classes))

        # Create the MultiDeviceIterator.
        with ops.device(self._source_device):
            # TODO(b/121378567): Get rid of this shared_name hack.
            shared_name = ""
            if context.executing_eagerly():
                shared_name = context.shared_name()
            self._multi_device_iterator_resource = (
                gen_dataset_ops.multi_device_iterator(
                    devices=self._devices,
                    shared_name=shared_name,
                    container="",
                    **dataset_ops.flat_structure(self._dataset)))
            if context.executing_eagerly():
                # Delete the resource when this object is deleted
                self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                    handle=self._multi_device_iterator_resource,
                    handle_device=self._source_device)

            # The incarnation ID is used to ensure consistency between the per-device
            # iterators and the multi-device iterator.
            self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
                self._dataset._variant_tensor,  # pylint: disable=protected-access
                self._multi_device_iterator_resource,
                max_buffer_size=self._max_buffer_size)

        self._prototype_device_datasets = []
        for i, device in enumerate(self._devices):
            with ops.device(device):
                ds = _PerDeviceGenerator(i,
                                         self._multi_device_iterator_resource,
                                         self._incarnation_id,
                                         self._source_device_tensor,
                                         self._dataset._element_structure)  # pylint: disable=protected-access
                self._prototype_device_datasets.append(ds)

        # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
        # initialize the device side of the pipeline. This would allow the
        # MultiDeviceIterator to choose, for example, to move some transformations
        # into the device side from its input. It might be useful in rewriting.
        # Create the per device iterators.
        self._device_iterators = []
        for i, device in enumerate(self._devices):
            with ops.device(device):
                ds = self._create_device_dataset(i)
                if context.executing_eagerly():
                    self._device_iterators.append(
                        dataset_ops.make_one_shot_iterator(ds))
                else:
                    self._device_iterators.append(
                        dataset_ops.make_initializable_iterator(ds))

        if not context.executing_eagerly():
            device_iterator_initializers = [
                iterator.initializer for iterator in self._device_iterators
            ]
            self._initializer = control_flow_ops.group(
                *device_iterator_initializers)
Example #42
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      self._flat_output_types = nest.flatten(
          sparse.as_dense_types(self._output_types, self._output_classes))
      self._flat_output_shapes = nest.flatten(
          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=_generate_shared_name("eageriterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
    self._buffer_resource_handle = None
    if not context.context().device_spec.device_type:
      is_remote_device = False
    else:
      is_remote_device = context.context().device_spec.device_type != "CPU"
    if is_remote_device:
      with ops.device("/device:CPU:0"):
        iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
            self._resource)

        @function.Defun(dtypes.string)
        def remote_fn(h):
          remote_iterator = iterator_ops.Iterator.from_string_handle(
              h, self._output_types, self._output_shapes)
          return remote_iterator.get_next()

        remote_fn.add_to_graph(None)
        target = constant_op.constant("/device:CPU:0")
      with ops.device(self._device):
        self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
            string_arg=iter_string_handle,
            f=remote_fn,
            target_device=target,
            buffer_size=10,
            thread_pool_size=1,
            container="",
            shared_name=_generate_shared_name("function_buffer_resource"))
        self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
            handle=self._buffer_resource_handle,
            handle_device=self._device)
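
Since this eager `Iterator` only works with eager execution enabled, a minimal driver looks like the sketch below (assuming the TF 1.x `tf.enable_eager_execution` entry point; with eager enabled, datasets are directly iterable, so the wrapper is rarely constructed by hand):

```python
import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x; eager is on by default in TF 2.x

dataset = tf.data.Dataset.range(4)
for x in dataset:  # eager datasets are directly iterable
  print(x.numpy())  # 0, 1, 2, 3
```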
Example #43
    def __init__(self,
                 dataset,
                 devices,
                 max_buffer_size=1,
                 prefetch_buffer_size=1,
                 source_device="/cpu:0"):
        """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we set up a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

    Raises:
      RuntimeError: If run in Eager mode.
    """
        if context.executing_eagerly():
            # TODO(rohanj): Fix this. Tracking bug: b/116467184
            raise RuntimeError(
                "MultiDeviceIterator is not currently supported in "
                "Eager mode.")
        self._dataset = dataset
        self._devices = devices
        self._source_device = source_device
        self._source_device_tensor = ops.convert_to_tensor(source_device)

        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._dataset.output_shapes,
                                   self._dataset.output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._dataset.output_types,
                                  self._dataset.output_classes))

        # Create the MultiDeviceIterator.
        with ops.device(self._source_device):
            self._multi_device_iterator_resource = (
                gen_dataset_ops.multi_device_iterator(
                    devices=self._devices,
                    shared_name="",
                    container="",
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes))

            # The incarnation ID is used to ensure consistency between the per-device
            # iterators and the multi-device iterator.
            self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
                self._dataset._as_variant_tensor(),  # pylint: disable=protected-access
                self._multi_device_iterator_resource,
                max_buffer_size=max_buffer_size)

        # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
        # initialize the device side of the pipeline. This would allow the
        # MultiDeviceIterator to choose, for example, to move some transformations
        # into the device side from its input. It might be useful in rewriting.
        # Create the per device iterators.
        self._device_iterators = []
        for i, device in enumerate(self._devices):
            ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                     self._incarnation_id,
                                     self._source_device_tensor, device,
                                     self._dataset.output_shapes,
                                     self._dataset.output_types,
                                     self._dataset.output_classes)
            if prefetch_buffer_size > 0:
                ds = ds.prefetch(prefetch_buffer_size)
            with ops.device(device):
                self._device_iterators.append(ds.make_initializable_iterator())

        device_iterator_initializers = [
            iterator.initializer for iterator in self._device_iterators
        ]
        self._initializer = control_flow_ops.group(
            *device_iterator_initializers)
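
A hedged graph-mode usage sketch for this class, assuming the public `tf.data.experimental.MultiDeviceIterator` wrapper and two visible GPUs:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(2)
mdi = tf.data.experimental.MultiDeviceIterator(
    dataset, devices=["/gpu:0", "/gpu:1"])
elem_gpu0, elem_gpu1 = mdi.get_next()  # one element per device

with tf.Session() as sess:
  sess.run(mdi.initializer)  # the grouped per-device initializers above
  print(sess.run([elem_gpu0, elem_gpu1]))
```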
Example #44
    def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
        """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
        super(_CopyToDeviceDataset, self).__init__(input_dataset)
        self._input_dataset = input_dataset
        self._target_device = target_device
        spec = framework_device.DeviceSpec().from_string(self._target_device)
        self._is_gpu_target = (spec.device_type == "GPU")
        self._source_device_string = source_device
        self._source_device = ops.convert_to_tensor(source_device)

        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._input_dataset.output_shapes,
                                   self._input_dataset.output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._input_dataset.output_types,
                                  self._input_dataset.output_classes))

        @function.defun()
        def _init_func():
            """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
            # pylint: disable=protected-access
            ds_variant = self._input_dataset._as_variant_tensor()
            resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            with ops.control_dependencies(
                [gen_dataset_ops.make_iterator(ds_variant, resource)]):
                return gen_dataset_ops.iterator_to_string_handle(resource)

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun()
        def _remote_init_func():
            return functional_ops.remote_call(
                target=self._source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _next_func(string_handle):
            """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
            with ops.device(self._source_device_string):
                iterator = iterator_ops.Iterator.from_string_handle(
                    string_handle, self.output_types, self.output_shapes,
                    self.output_classes)
            ret = iterator.get_next()
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=self._flat_output_types,
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _finalize_func(string_handle):
            """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
            iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
                string_handle,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            with ops.control_dependencies([
                    resource_variable_ops.destroy_resource_op(
                        iterator_resource, ignore_lookup_error=True)
            ]):
                return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
        )
        self._finalize_captured_args = self._finalize_func.captured_inputs

        g = ops.get_default_graph()
        self._init_func.add_to_graph(g)
        self._next_func.add_to_graph(g)
        self._finalize_func.add_to_graph(g)
Example #45
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Example #46
  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s representing the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    @{tf.Session.run} on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    @{tf.errors.OutOfRangeError}. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.as_dense_types(
                                          self._output_types,
                                          self._output_classes)),
                                  output_shapes=nest.flatten(
                                      sparse.as_dense_shapes(
                                          self._output_shapes,
                                          self._output_classes)),
                                  name=name)), self._output_types,
        self._output_shapes, self._output_classes)
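
The pitfall the docstring warns about is worth spelling out: `get_next()` must be called once, outside the loop, and its result reused; calling it inside the loop adds new ops to the graph on every iteration. A minimal correct loop:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  # created once, outside the loop

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      break
```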
Example #47
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, target_device, output_shapes,
                 output_types, output_classes):
        self._target_device = target_device
        self._output_types = output_types
        self._output_shapes = output_shapes
        self._output_classes = output_classes
        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._output_shapes, self._output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes))

        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        @function.Defun()
        def _init_func():
            return multi_device_iterator_string_handle

        @function.Defun()
        def _remote_init_func():
            return functional_ops.remote_call(target=source_device,
                                              args=_init_func.captured_inputs,
                                              Tout=[dtypes.string],
                                              f=_init_func)

        self._init_func = _remote_init_func
        self._init_captured_args = _remote_init_func.captured_inputs

        @function.Defun(dtypes.string)
        def _next_func(string_handle):
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

        @function.Defun(dtypes.string, experimental_ints_on_device=True)
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(target=source_device,
                                              args=[string_handle] +
                                              _next_func.captured_inputs,
                                              Tout=self._flat_output_types,
                                              f=_next_func)

        self._next_func = _remote_next_func
        self._next_captured_args = _remote_next_func.captured_inputs

        @function.Defun(dtypes.string)
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        @function.Defun(dtypes.string)
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(target=source_device,
                                              args=[string_handle] +
                                              _finalize_func.captured_inputs,
                                              Tout=[dtypes.int64],
                                              f=_finalize_func)

        self._finalize_func = _remote_finalize_func
        self._finalize_captured_args = _remote_finalize_func.captured_inputs
Example #48
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, target_device, output_shapes, output_types,
               output_classes):
    self._target_device = target_device
    self._output_types = output_types
    self._output_shapes = output_shapes
    self._output_classes = output_classes
    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    @function.defun()
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

    next_func_concrete = _next_func.get_concrete_function()
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._flat_output_types,
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] +
          finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs
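
Both versions of `_PerDeviceGenerator` wire the same three-function protocol: `init` returns a handle, `next` is called with that handle until the sequence is exhausted, and `finalize` releases resources. A pure-Python analogy of that contract (a sketch, not TensorFlow's implementation):

```python
class GeneratorProtocolSketch:
  """Mimics the init/next/finalize contract used above."""

  def __init__(self, init_fn, next_fn, finalize_fn):
    self._init, self._next, self._finalize = init_fn, next_fn, finalize_fn

  def __iter__(self):
    handle = self._init()           # like _init_func: produce a handle
    try:
      while True:
        try:
          yield self._next(handle)  # like _next_func: fetch via the handle
        except StopIteration:
          break
    finally:
      self._finalize(handle)        # like _finalize_func: clean up

data = [1, 2, 3]
sketch = GeneratorProtocolSketch(
    init_fn=lambda: iter(data),
    next_fn=next,                   # raises StopIteration when exhausted
    finalize_fn=lambda handle: None)
print(list(sketch))  # [1, 2, 3]
```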