示例#1
0
    def _next_internal(self):
        """Returns a nested structure of `tf.Tensor`s containing the next element.

        In graph mode this builds an `iterator_get_next` op under
        `ops.device(self._device)` and packs its flat outputs into
        `self._element_spec`. In eager mode the same op is run under
        `context.SYNC` execution and the result is converted back into the
        structure described by `self._element_spec`.
        """
        if not context.executing_eagerly():
            with ops.device(self._device):
                ret = gen_dataset_ops.iterator_get_next(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)
            return structure.from_compatible_tensor_list(
                self._element_spec, ret)

        # This runs in sync mode as iterators use an error status to communicate
        # that there is no more data to iterate over.
        # TODO(b/77291417): Fix
        with context.execution_mode(context.SYNC):
            with ops.device(self._device):
                # TODO(ashankar): Consider removing this ops.device() contextmanager
                # and instead mimic ops placement in graphs: Operations on resource
                # handles execute on the same device as where the resource is placed.
                ret = gen_dataset_ops.iterator_get_next(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)

            # Fast path for the case `self._element_spec` is not a nested
            # structure and converts the flat list itself. The method is looked
            # up rather than called inside `try/except AttributeError`.
            # - Nested specs then don't raise and catch an exception on every
            #   element.
            # - An AttributeError raised *inside* the conversion isn't masked
            #   and silently retried via the slow path.
            from_flat = getattr(self._element_spec,
                                "_from_compatible_tensor_list", None)
            if from_flat is not None:
                return from_flat(ret)
            return structure.from_compatible_tensor_list(
                self._element_spec, ret)
示例#2
0
    def _next_internal(self):
        """Returns a nested structure of `tf.Tensor`s containing the next element.

        In graph mode the `iterator_get_next` op is colocated with
        `self._iterator_resource` and its flat outputs are packed into
        `self._element_spec`. In eager mode the op runs under `context.SYNC`
        execution and the result is converted back into that structure.
        """
        if not context.executing_eagerly():
            # TODO(b/169442955): Investigate the need for this colocation constraint.
            with ops.colocate_with(self._iterator_resource):
                ret = gen_dataset_ops.iterator_get_next(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)
            return structure.from_compatible_tensor_list(
                self._element_spec, ret)

        # TODO(b/77291417): This runs in sync mode as iterators use an error status
        # to communicate that there is no more data to iterate over.
        with context.execution_mode(context.SYNC):
            ret = gen_dataset_ops.iterator_get_next(
                self._iterator_resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

            # Fast path for the case `self._element_spec` is not a nested
            # structure and converts the flat list itself. The method is looked
            # up rather than called inside `try/except AttributeError`.
            # - Nested specs then don't raise and catch an exception on every
            #   element.
            # - An AttributeError raised *inside* the conversion isn't masked
            #   and silently retried via the slow path.
            from_flat = getattr(self._element_spec,
                                "_from_compatible_tensor_list", None)
            if from_flat is not None:
                return from_flat(ret)
            return structure.from_compatible_tensor_list(
                self._element_spec, ret)
示例#3
0
    def _next_internal(self):
        """Returns a nested structure of `tf.Tensor`s containing the next element.

        When not executing eagerly and autograph is disabled, each call
        increments `self._get_next_call_count`. It raises once that count
        exceeds `GET_NEXT_CALL_ERROR_THRESHOLD`. This presumably guards against
        graph code that creates a new get-next op per step; confirm against
        `GET_NEXT_CALL_ERROR_MESSAGE`.

        Raises:
          ValueError: if the call-count threshold described above is exceeded.
        """
        autograph_status = autograph_ctx.control_status_ctx().status
        autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
        # Query the execution mode once; both branches below depend on it.
        eager = context.executing_eagerly()

        if not eager and autograph_disabled:
            self._get_next_call_count += 1
            if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
                raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)

        if not eager:
            # TODO(b/169442955): Investigate the need for this colocation constraint.
            with ops.colocate_with(self._iterator_resource):
                ret = gen_dataset_ops.iterator_get_next(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)
            return structure.from_compatible_tensor_list(
                self._element_spec, ret)

        # TODO(b/77291417): This runs in sync mode as iterators use an error status
        # to communicate that there is no more data to iterate over.
        with context.execution_mode(context.SYNC):
            ret = gen_dataset_ops.iterator_get_next(
                self._iterator_resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

            # Fast path for the case `self._element_spec` is not a nested
            # structure and converts the flat list itself. The method is looked
            # up rather than called inside `try/except AttributeError`.
            # - Nested specs then don't raise and catch an exception on every
            #   element.
            # - An AttributeError raised *inside* the conversion isn't masked
            #   and silently retried via the slow path.
            from_flat = getattr(self._element_spec,
                                "_from_compatible_tensor_list", None)
            if from_flat is not None:
                return from_flat(ret)
            return structure.from_compatible_tensor_list(
                self._element_spec, ret)
示例#4
0
  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    In graph mode this builds an `iterator_get_next` op under
    `ops.device(self._device)`. In eager mode it runs
    `iterator_get_next_sync` under `context.SYNC` execution. Either way, the
    flat outputs are converted into the structure of `self._element_spec`.
    """
    if not context.executing_eagerly():
      with ops.device(self._device):
        ret = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      return structure.from_compatible_tensor_list(self._element_spec, ret)

    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
      with ops.device(self._device):
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        ret = gen_dataset_ops.iterator_get_next_sync(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

      # Fast path for the case `self._element_spec` is not a nested structure
      # and converts the flat list itself. The method is looked up rather than
      # called inside `try/except AttributeError`.
      # - Nested specs then don't raise and catch an exception on every
      #   element.
      # - An AttributeError raised *inside* the conversion isn't masked and
      #   silently retried via the slow path.
      from_flat = getattr(self._element_spec, "_from_compatible_tensor_list",
                          None)
      if from_flat is not None:
        return from_flat(ret)
      return structure.from_compatible_tensor_list(self._element_spec, ret)
示例#5
0
    def testNestedNestedStructure(self):
        """Round-trips a doubly nested spec through the flat tensor list.

        Checks that `to_tensor_list`, `from_tensor_list` and
        `from_compatible_tensor_list` hand back the very same tensor objects
        (identity via `assertIs`, not just equality).
        """
        s = (tensor_spec.TensorSpec([], dtypes.int64),
             (tensor_spec.TensorSpec([], dtypes.float32),
              tensor_spec.TensorSpec([], dtypes.string)))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        nested_tensors = (int64_t, (float32_t, string_t))
        expected_flat = [int64_t, float32_t, string_t]

        tensor_list = structure.to_tensor_list(s, nested_tensors)
        # `zip` stops at the shorter input. Without this length check, a short
        # (or empty) `tensor_list` would make the identity loop pass vacuously.
        self.assertEqual(len(expected_flat), len(tensor_list))
        for expected, actual in zip(expected_flat, tensor_list):
            self.assertIs(expected, actual)

        (actual_int64_t,
         (actual_float32_t,
          actual_string_t)) = structure.from_tensor_list(s, tensor_list)
        self.assertIs(int64_t, actual_int64_t)
        self.assertIs(float32_t, actual_float32_t)
        self.assertIs(string_t, actual_string_t)

        (actual_int64_t,
         (actual_float32_t,
          actual_string_t)) = (structure.from_compatible_tensor_list(
              s, tensor_list))
        self.assertIs(int64_t, actual_int64_t)
        self.assertIs(float32_t, actual_float32_t)
        self.assertIs(string_t, actual_string_t)
 def py_function_wrapper(*args):
     """Calls `self._func` on structured inputs and returns flat tensors.

     Flow:
     1. Pack the flat `args` according to `self._input_structure`.
     2. Spread the result as positional arguments, or pass it as one
        argument when `_should_unpack` says not to spread it.
     3. Tuple-ize the return value when `_should_pack` requests it.
     4. Flatten it against `self._output_structure`.
     5. Convert every component with `ops.convert_to_tensor`.
     """
     structured_args = structure.from_compatible_tensor_list(
         self._input_structure, args)
     call_args = (structured_args if _should_unpack(structured_args)
                  else (structured_args,))
     result = self._func(*call_args)
     if _should_pack(result):
         result = tuple(result)
     flat_result = structure.to_tensor_list(self._output_structure, result)
     return [ops.convert_to_tensor(component) for component in flat_result]
示例#7
0
    def wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions.

      Flow:
      1. Pack the flat `args` into `self._input_structure`.
      2. Call the autograph-converted `self._func`.
      3. Convert any variables in the result to tensors.
      4. Tuple-ize the result when `_should_pack` requests it.
      5. Record its type spec in `self._output_structure`.

      Raises:
        TypeError: if `structure.type_spec_from_value` rejects the return
          value (the original ValueError/TypeError is chained as the cause).
      """
      structured = structure.from_compatible_tensor_list(
          self._input_structure, args)
      call_args = structured if _should_unpack(structured) else (structured,)
      converted_func = autograph.tf_convert(self._func, ag_ctx)
      result = variable_utils.convert_variables_to_tensors(
          converted_func(*call_args))
      if _should_pack(result):
        result = tuple(result)

      try:
        self._output_structure = structure.type_spec_from_value(result)
      except (ValueError, TypeError) as err:
        raise TypeError(f"Unsupported return value from function passed to "
                        f"{transformation_name}: {result}.") from err
      return result
示例#8
0
def get_single_element(dataset):
    """Returns the single element in `dataset` as a nested structure of tensors.

    This function enables you to use a `tf.data.Dataset` in a stateless
    "tensor-in tensor-out" expression, without creating a
    `tf.compat.v1.data.Iterator`.
    This can be useful when your preprocessing transformations are expressed
    as a `Dataset`, and you want to use the transformation at serving time.
    For example:

    ```python
    input_batch = tf.compat.v1.placeholder(tf.string, shape=[BATCH_SIZE])

    def preprocessing_fn(input_str):
      # ...
      return image, label

    dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
    ```

    Args:
      dataset: A `tf.data.Dataset` object containing a single element.

    Returns:
      A nested structure of `tf.Tensor` objects, corresponding to the single
      element of `dataset`.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset` object.
      InvalidArgumentError (at runtime): if `dataset` does not contain exactly
        one element.
    """
    is_dataset = isinstance(dataset, dataset_ops.DatasetV2)
    if not is_dataset:
        raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

    # Read the spec before building the op, matching the original argument
    # evaluation order.
    element_spec = dataset.element_spec
    # pylint: disable=protected-access
    flat_tensors = gen_dataset_ops.dataset_to_single_element(
        dataset._variant_tensor, **dataset._flat_structure)
    # pylint: enable=protected-access
    return structure.from_compatible_tensor_list(element_spec, flat_tensors)