def _next_internal(self):
    """Fetches the next element from the iterator resource.

    The flat outputs of the `iterator_get_next` op are rebuilt into the
    structure described by `self._element_spec`.

    Returns:
      A nested structure of `tf.Tensor`s containing the next element.
    """
    if context.executing_eagerly():
        # Iterators use an error status to communicate that there is no more
        # data to iterate over, so the op is executed in sync mode.
        # TODO(b/77291417): Fix
        with context.execution_mode(context.SYNC):
            # TODO(ashankar): Consider removing this ops.device()
            # contextmanager and instead mimic ops placement in graphs:
            # Operations on resource handles execute on the same device as
            # where the resource is placed.
            with ops.device(self._device):
                flat_tensors = gen_dataset_ops.iterator_get_next(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)

            try:
                # Fast path: use the spec's own conversion method when it
                # has one (i.e. the spec is not a nested structure).
                return self._element_spec._from_compatible_tensor_list(  # pylint: disable=protected-access
                    flat_tensors)
            except AttributeError:
                # No such method on the spec: use the generic helper.
                return structure.from_compatible_tensor_list(
                    self._element_spec, flat_tensors)

    # Graph mode: create the op on the iterator's device.
    with ops.device(self._device):
        flat_tensors = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
    return structure.from_compatible_tensor_list(
        self._element_spec, flat_tensors)
def _next_internal(self):
    """Fetches the next element from the iterator resource.

    In graph mode the `iterator_get_next` op is colocated with the iterator
    resource; in eager mode it is executed in sync mode. The flat outputs are
    rebuilt into the structure described by `self._element_spec`.

    Returns:
      The next element, structured according to `self._element_spec`.
    """
    if context.executing_eagerly():
        # TODO(b/77291417): This runs in sync mode as iterators use an error
        # status to communicate that there is no more data to iterate over.
        with context.execution_mode(context.SYNC):
            flat_tensors = gen_dataset_ops.iterator_get_next(
                self._iterator_resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

            try:
                # Fast path: use the spec's own conversion method when it
                # has one (i.e. the spec is not a nested structure).
                return self._element_spec._from_compatible_tensor_list(  # pylint: disable=protected-access
                    flat_tensors)
            except AttributeError:
                # No such method on the spec: use the generic helper.
                return structure.from_compatible_tensor_list(
                    self._element_spec, flat_tensors)

    # Graph mode.
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
        flat_tensors = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
    return structure.from_compatible_tensor_list(
        self._element_spec, flat_tensors)
def _next_internal(self):
    """Fetches the next element, guarding against excessive graph-mode calls.

    When not executing eagerly and AutoGraph is disabled, every call
    increments `self._get_next_call_count`; once that count exceeds
    `GET_NEXT_CALL_ERROR_THRESHOLD` a `ValueError` is raised.

    Returns:
      The next element, structured according to `self._element_spec`.

    Raises:
      ValueError: if the call count exceeds `GET_NEXT_CALL_ERROR_THRESHOLD`
        while building a graph with AutoGraph disabled.
    """
    current_status = autograph_ctx.control_status_ctx().status
    autograph_off = current_status == autograph_ctx.Status.DISABLED
    if not context.executing_eagerly() and autograph_off:
        self._get_next_call_count += 1
        if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
            raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)

    if context.executing_eagerly():
        # TODO(b/77291417): This runs in sync mode as iterators use an error
        # status to communicate that there is no more data to iterate over.
        with context.execution_mode(context.SYNC):
            flat_tensors = gen_dataset_ops.iterator_get_next(
                self._iterator_resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

            try:
                # Fast path: use the spec's own conversion method when it
                # has one (i.e. the spec is not a nested structure).
                return self._element_spec._from_compatible_tensor_list(  # pylint: disable=protected-access
                    flat_tensors)
            except AttributeError:
                # No such method on the spec: use the generic helper.
                return structure.from_compatible_tensor_list(
                    self._element_spec, flat_tensors)

    # Graph mode.
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
        flat_tensors = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
    return structure.from_compatible_tensor_list(
        self._element_spec, flat_tensors)
def _next_internal(self):
    """Fetches the next element from the iterator resource.

    Graph mode uses the `iterator_get_next` op; eager mode uses the
    `iterator_get_next_sync` op. Either way the flat outputs are rebuilt into
    the structure described by `self._element_spec`.

    Returns:
      A nested structure of `tf.Tensor`s containing the next element.
    """
    if context.executing_eagerly():
        # Iterators use an error status to communicate that there is no more
        # data to iterate over, so the op is executed in sync mode.
        # TODO(b/77291417): Fix
        with context.execution_mode(context.SYNC):
            # TODO(ashankar): Consider removing this ops.device()
            # contextmanager and instead mimic ops placement in graphs:
            # Operations on resource handles execute on the same device as
            # where the resource is placed.
            with ops.device(self._device):
                # NOTE(mrry): The "_sync" variant of `iterator_get_next` is
                # used because in eager mode this code runs synchronously on
                # the calling thread. No defensive context switch to a
                # background thread is needed, and invoking the iterator
                # synchronously gives a small constant performance boost.
                flat_tensors = gen_dataset_ops.iterator_get_next_sync(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)

            try:
                # Fast path: use the spec's own conversion method when it
                # has one (i.e. the spec is not a nested structure).
                return self._element_spec._from_compatible_tensor_list(  # pylint: disable=protected-access
                    flat_tensors)
            except AttributeError:
                # No such method on the spec: use the generic helper.
                return structure.from_compatible_tensor_list(
                    self._element_spec, flat_tensors)

    # Graph mode: create the op on the iterator's device.
    with ops.device(self._device):
        flat_tensors = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
    return structure.from_compatible_tensor_list(
        self._element_spec, flat_tensors)
def testNestedNestedStructure(self):
    """Round-trips a two-level nested tuple through its flat tensor list.

    Checks that flattening yields the very same tensor objects, and that both
    `from_tensor_list` and `from_compatible_tensor_list` hand those same
    objects back in the original nesting.
    """
    int64_spec = tensor_spec.TensorSpec([], dtypes.int64)
    float32_spec = tensor_spec.TensorSpec([], dtypes.float32)
    string_spec = tensor_spec.TensorSpec([], dtypes.string)
    spec = (int64_spec, (float32_spec, string_spec))

    int64_t = constant_op.constant(37, dtype=dtypes.int64)
    float32_t = constant_op.constant(42.0)
    string_t = constant_op.constant("Foo")

    # Flattening must preserve identity and depth-first order.
    flat = structure.to_tensor_list(spec, (int64_t, (float32_t, string_t)))
    for want, got in zip([int64_t, float32_t, string_t], flat):
        self.assertIs(want, got)

    # Rebuild with the strict variant.
    got_int64, (got_float32, got_string) = structure.from_tensor_list(
        spec, flat)
    self.assertIs(int64_t, got_int64)
    self.assertIs(float32_t, got_float32)
    self.assertIs(string_t, got_string)

    # Rebuild with the "compatible" variant.
    rebuilt = structure.from_compatible_tensor_list(spec, flat)
    got_int64, (got_float32, got_string) = rebuilt
    self.assertIs(int64_t, got_int64)
    self.assertIs(float32_t, got_float32)
    self.assertIs(string_t, got_string)
def py_function_wrapper(*args):
    """Adapts flat tensor arguments to `self._func` and flattens its result.

    Args:
      *args: Flat tensors, rebuilt according to `self._input_structure`.

    Returns:
      A list of tensors: the return value of `self._func` flattened with
      `self._output_structure`, each entry passed through
      `ops.convert_to_tensor`.
    """
    unpacked = structure.from_compatible_tensor_list(
        self._input_structure, args)
    # A value that should not be unpacked is passed as one positional arg.
    call_args = unpacked if _should_unpack(unpacked) else (unpacked,)

    result = self._func(*call_args)
    if _should_pack(result):
        result = tuple(result)

    flat_outputs = structure.to_tensor_list(self._output_structure, result)
    return [ops.convert_to_tensor(tensor) for tensor in flat_outputs]
def wrapper_helper(*args):
    """Passes nested structures to and from a tf.data user function.

    Rebuilds `args` according to `self._input_structure`, calls the
    AutoGraph-converted `self._func`, converts variables in the result to
    tensors, and records the result's type spec in `self._output_structure`.

    Args:
      *args: Flat tensors making up the function's input element.

    Returns:
      The (possibly tuple-packed) return value of the user function.

    Raises:
      TypeError: if no type spec can be derived from the return value.
    """
    unpacked = structure.from_compatible_tensor_list(
        self._input_structure, args)
    # A value that should not be unpacked is passed as one positional arg.
    call_args = unpacked if _should_unpack(unpacked) else (unpacked,)

    converted_func = autograph.tf_convert(self._func, ag_ctx)
    result = converted_func(*call_args)
    result = variable_utils.convert_variables_to_tensors(result)
    if _should_pack(result):
        result = tuple(result)

    try:
        self._output_structure = structure.type_spec_from_value(result)
    except (ValueError, TypeError) as e:
        raise TypeError(
            "Unsupported return value from function passed to "
            f"{transformation_name}: {result}.") from e
    return result
def get_single_element(dataset):
    """Returns the single element in `dataset` as a nested structure of tensors.

    With this function a `tf.data.Dataset` can be used in a stateless
    "tensor-in tensor-out" expression, without creating a
    `tf.compat.v1.data.Iterator`. That is useful when preprocessing
    transformations are expressed as a `Dataset` and the same transformation
    is wanted at serving time.

    For example:

    ```python
    input_batch = tf.compat.v1.placeholder(tf.string, shape=[BATCH_SIZE])

    def preprocessing_fn(input_str):
      # ...
      return image, label

    dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
    ```

    Args:
      dataset: A `tf.data.Dataset` object containing a single element.

    Returns:
      A nested structure of `tf.Tensor` objects, corresponding to the single
      element of `dataset`.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset` object.
      InvalidArgumentError (at runtime): if `dataset` does not contain exactly
        one element.
    """
    if isinstance(dataset, dataset_ops.DatasetV2):
        element_spec = dataset.element_spec
        flat_tensors = gen_dataset_ops.dataset_to_single_element(
            dataset._variant_tensor,  # pylint: disable=protected-access
            **dataset._flat_structure)  # pylint: disable=protected-access
        return structure.from_compatible_tensor_list(
            element_spec, flat_tensors)

    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")