def get_single_element(dataset):
  """Extracts the only element of `dataset` as a nested structure of tensors.

  Use this to evaluate a @{tf.data.Dataset} as a stateless "tensor-in
  tensor-out" expression, with no @{tf.data.Iterator} involved. That is
  handy when preprocessing is written as `Dataset` transformations and the
  same transformations have to be applied at serving time.

  For example:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
  ```

  Args:
    dataset: A @{tf.data.Dataset} object containing a single element.

  Returns:
    A nested structure of @{tf.Tensor} objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  types = dataset.output_types
  variant = dataset._as_variant_tensor()  # pylint: disable=protected-access

  # Flatten the dense type/shape descriptions so they can be handed to the op
  # as plain lists.
  flat_types = nest.flatten(
      sparse.as_dense_types(dataset.output_types, dataset.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(dataset.output_shapes, dataset.output_classes))

  flat_components = gen_dataset_ops.dataset_to_single_element(
      variant, output_types=flat_types, output_shapes=flat_shapes)

  # Re-nest the flat op outputs to mirror `output_types`, then hand the result
  # to `sparse.deserialize_sparse_tensors` to produce the final structure.
  nested_components = nest.pack_sequence_as(types, flat_components)
  return sparse.deserialize_sparse_tensors(
      nested_components,
      dataset.output_types,
      dataset.output_shapes,
      dataset.output_classes)
def get_single_element(dataset):
  """Extracts the only element of `dataset` as a nested structure of tensors.

  Use this to evaluate a `tf.data.Dataset` as a stateless "tensor-in
  tensor-out" expression, with no `tf.compat.v1.data.Iterator` involved.
  That is handy when preprocessing is written as `Dataset` transformations
  and the same transformations have to be applied at serving time.

  For example:

  ```python
  input_batch = tf.compat.v1.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  element_spec = dataset.element_spec

  # The op returns a flat list of tensors; `element_spec` is used below to
  # rebuild the nested element from that list.
  flat_components = gen_dataset_ops.dataset_to_single_element(
      dataset._variant_tensor, **dataset._flat_structure)  # pylint: disable=protected-access

  return structure.from_compatible_tensor_list(element_spec, flat_components)
def get_single_element(dataset):
  """Extracts the only element of `dataset` as a nested structure of tensors.

  Use this to evaluate a `tf.data.Dataset` as a stateless "tensor-in
  tensor-out" expression, with no `tf.data.Iterator` involved. That is handy
  when preprocessing is written as `Dataset` transformations and the same
  transformations have to be applied at serving time.

  For example:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  # pylint: disable=protected-access
  # Look up the re-nesting method before building the op, then apply it to the
  # flat list of tensors the op returns.
  rebuild_element = dataset._element_structure._from_compatible_tensor_list
  flat_components = gen_dataset_ops.dataset_to_single_element(
      dataset._variant_tensor, **dataset_ops.flat_structure(dataset))
  # pylint: enable=protected-access

  return rebuild_element(flat_components)