def testSerializeDeserialize(self):
    """Round-trips nested structures through sparse (de)serialization.

    Every test case is serialized with `sparse.serialize_sparse_tensors`,
    restored with `sparse.deserialize_sparse_tensors`, and then compared with
    the original, first structurally and then component by component.
    """

    def unit_sparse():
        # A fresh 1x1 sparse tensor holding a single `1`.
        return sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])

    test_cases = (
        (),
        unit_sparse(),
        sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
        # The trailing comma is required: without it the parentheses are
        # plain grouping and this case would just repeat the bare
        # `SparseTensor` case above instead of exercising a 1-tuple.
        (unit_sparse(),),
        (unit_sparse(), ()),
        ((), unit_sparse()),
    )
    for expected in test_cases:
        classes = sparse.get_classes(expected)
        shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), classes)
        types = nest.map_structure(lambda _: dtypes.int32, classes)
        # `classes` is reused here rather than recomputed from `expected`.
        actual = sparse.deserialize_sparse_tensors(
            sparse.serialize_sparse_tensors(expected), types, shapes, classes)
        nest.assert_same_structure(expected, actual)
        for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
            self.assertSparseValuesEqual(a, e)
def from_string_handle(string_handle, output_types, output_shapes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    Use this to define a "feedable" iterator, where the concrete iterator to
    read from is chosen by feeding a value in a `tf.Session.run` call. In that
    case `string_handle` would be a `tf.placeholder`, fed in each step with the
    value of `tf.data.Iterator.string_handle`.

    For example, with one iterator tracking the position in a training dataset
    and another tracking the position in a test dataset, you can pick which
    one to use in each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
    nest.assert_same_structure(output_types, output_shapes)

    handle_tensor = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    flat_types = nest.flatten(sparse.unwrap_sparse_types(output_types))
    flat_shapes = nest.flatten(output_shapes)
    resource = gen_dataset_ops.iterator_from_string_handle(
        handle_tensor, output_types=flat_types, output_shapes=flat_shapes)
    return Iterator(resource, None, output_types, output_shapes)
def __init__(self, input_dataset, num_workers):
    """Creates a dataset whose leading (batch) dimension is split by workers.

    The reported output shapes are those of `input_dataset` with the first
    dimension floor-divided by `num_workers`; types and classes are inherited
    unchanged.

    Args:
      input_dataset: The dataset to rebatch.
      num_workers: The number of workers the leading dimension is divided by.

    Raises:
      ValueError: If a component shape has no dimensions.
      errors.InvalidArgumentError: If a statically known, non-zero leading
        dimension is not divisible by `num_workers`.
    """
    self._input_dataset = input_dataset

    def _divide_leading_dim(shape):
        """Returns `shape` with its first dimension divided by num_workers."""
        if len(shape) < 1:
            raise ValueError("Input shape should have at least one dimension.")
        leading = tensor_shape.dimension_value(shape[0])
        # An unknown (None) or zero leading dimension skips the check.
        if leading and leading % num_workers != 0:
            raise errors.InvalidArgumentError(
                None, None,
                "First dim of input shape: %d is not divisible by num_workers: %d"
                % (shape[0], num_workers))
        dims = list(shape.dims)
        dims[0] = dims[0] // num_workers
        return tensor_shape.TensorShape(dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    rebatched_shapes = nest.map_structure(_divide_leading_dim, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, rebatched_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation.

    Keeps only the elements whose first flattened component has a leading
    dimension equal to `batch_size` (taken from the enclosing scope), then
    restores the original nesting with that leading dimension merged into
    every output shape.
    """
    tensor_batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")

    flat_types = tuple(nest.flatten(dataset.output_types))
    flat_classes = tuple(nest.flatten(dataset.output_classes))
    flattened = _RestructuredDataset(
        dataset, flat_types, output_classes=flat_classes)

    def _predicate(*components):
        """Return `True` if this element is a full batch."""
        # The dynamic batch size is read off the first flattened component.
        leading = array_ops.shape(components[0], out_type=dtypes.int64)[0]
        return math_ops.equal(leading, tensor_batch_size)

    filtered = flattened.filter(_predicate)

    static_batch_size = tensor_util.constant_value(tensor_batch_size)

    def _with_batch_dim(shape):
        batch_dim = tensor_shape.vector(static_batch_size)
        return shape.merge_with(batch_dim.concatenate(shape[1:]))

    known_shapes = nest.map_structure(_with_batch_dim, dataset.output_shapes)
    return _RestructuredDataset(
        filtered,
        dataset.output_types,
        known_shapes,
        output_classes=dataset.output_classes)
def __init__(self, dataset, output_types, output_shapes=None):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:

    * `dataset.output_types` must be the same as `output_types` modulo
      nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not
        compatible with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # The types must agree once nesting is ignored.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if nest.flatten(dataset.output_types) != nest.flatten(output_types):
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have output "
            "types %r" % (dataset.output_types, output_types))
    self._output_types = output_types

    if output_shapes is None:
        # Re-nest the original shapes to follow the new type structure.
        self._output_shapes = nest.pack_sequence_as(
            output_types, nest.flatten(dataset.output_shapes))
        return

    # Explicit shapes: each one must be compatible with the shape it replaces.
    nest.assert_same_structure(output_types, output_shapes)
    existing_shapes = nest.flatten(dataset.output_shapes)
    requested_shapes = nest.flatten_up_to(output_types, output_shapes)
    for existing, requested in zip(existing_shapes, requested_shapes):
        if not existing.is_compatible_with(requested):
            raise ValueError(
                "Dataset with output shapes %r cannot be restructured to have "
                "incompatible output shapes %r"
                % (dataset.output_shapes, output_shapes))
    self._output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
def __init__(self, variant_tensor, output_shapes, output_types,
             output_classes):
    """Stores a variant tensor together with its normalized element structure.

    Args:
      variant_tensor: Stored as-is on `self._variant_tensor`.
      output_shapes: A nested structure convertible to `tf.TensorShape`
        objects, following the structure of `output_types`.
      output_types: A nested structure of values convertible to `tf.DType`.
      output_classes: A nested structure with the same structure as
        `output_types`.
    """
    # TODO(b/110122868): Consolidate the structure validation logic with the
    # similar logic in `Iterator.from_structure()` and
    # `Dataset.from_generator()`.
    types = nest.map_structure(dtypes.as_dtype, output_types)
    shapes = nest.map_structure_up_to(
        types, tensor_shape.as_shape, output_shapes)
    # Shapes and classes must both mirror the nesting of the types.
    for other in (shapes, output_classes):
        nest.assert_same_structure(types, other)

    self._variant_tensor = variant_tensor
    self._output_shapes = shapes
    self._output_types = types
    self._output_classes = output_classes
def __init__(self, driver_name, data_source_name, query, output_types):
    """Creates a `SqlDataset`.

    `SqlDataset` allows a user to read data from the result set of a SQL
    query. For example:

    ```python
    tf.enable_eager_execution()

    dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                              "SELECT name, age FROM people",
                                              (tf.string, tf.int32))
    # Prints the rows of the result set of the above query.
    for element in dataset:
      print(element)
    ```

    Args:
      driver_name: A 0-D `tf.string` tensor containing the database type.
        Currently, the only supported value is 'sqlite'.
      data_source_name: A 0-D `tf.string` tensor containing a connection
        string to connect to the database.
      query: A 0-D `tf.string` tensor containing the SQL query to execute.
      output_types: A tuple of `tf.DType` objects representing the types of
        the columns returned by `query`.
    """

    def _string_tensor(value, name):
        return ops.convert_to_tensor(value, dtype=dtypes.string, name=name)

    def _scalar_structure(dtype):
        # Each column is described as a tensor of shape `[]`.
        return structure.TensorStructure(dtype, [])

    self._driver_name = _string_tensor(driver_name, "driver_name")
    self._data_source_name = _string_tensor(
        data_source_name, "data_source_name")
    self._query = _string_tensor(query, "query")
    self._structure = structure.NestedStructure(
        nest.map_structure(_scalar_structure, output_types))

    variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
        self._driver_name,
        self._data_source_name,
        self._query,
        **dataset_ops.flat_structure(self))
    super(SqlDatasetV2, self).__init__(variant_tensor)
def __init__(self, input_dataset):
    """See `unbatch()` for more details.

    Args:
      input_dataset: The dataset to unbatch. The leading dimension is dropped
        from each of its output shapes.

    Raises:
      ValueError: If any component is a scalar, or if the leading dimensions
        of the components cannot be merged into a single batch size.
    """
    flat_shapes = nest.flatten(input_dataset.output_shapes)
    for shape in flat_shapes:
        if shape.ndims == 0:
            raise ValueError("Cannot unbatch an input with scalar components.")

    # Merge every leading dimension into one; a conflict means the components
    # disagree on the batch size. The merged value is only used as a check.
    batch_dim = tensor_shape.Dimension(None)
    for shape in flat_shapes:
        try:
            batch_dim = batch_dim.merge_with(shape[0])
        except ValueError:
            raise ValueError("Cannot unbatch an input whose components have "
                             "different batch sizes.")

    self._input_dataset = input_dataset

    def _drop_batch_dim(shape):
        return shape[1:]

    self._structure = structure.convert_legacy_structure(
        input_dataset.output_types,
        nest.map_structure(_drop_batch_dim, input_dataset.output_shapes),
        input_dataset.output_classes)

    variant_tensor = ged_ops.experimental_unbatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **dataset_ops.flat_structure(self))
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
    """Initialize `PrependFromQueueAndPaddedBatchDataset`.

    Args:
      input_dataset: The dataset to batch; it must not contain sparse
        components.
      batch_size: Value convertible to a `tf.int64` tensor.
      padded_shapes: Nested structure of shapes following
        `input_dataset.output_shapes`, or `None` to use the input dataset's
        own output shapes.
      padding_values: Nested structure of padding values following
        `input_dataset.output_shapes`, or `None` to use
        `dataset_ops._default_padding(input_dataset)`.

    Raises:
      TypeError: If `input_dataset` has sparse components.
    """
    super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
    if sparse.any_sparse(input_dataset.output_classes):
        raise TypeError(
            "Batching of padded sparse tensors is not currently supported")

    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")

    # pylint: disable=protected-access
    shape_to_tensor = dataset_ops._partial_shape_to_tensor
    input_shapes = input_dataset.output_shapes
    if padded_shapes is not None:
        self._padded_shapes = nest.map_structure_up_to(
            input_shapes, shape_to_tensor, padded_shapes)
    else:
        self._padded_shapes = nest.map_structure(shape_to_tensor, input_shapes)

    if padding_values is None:
        padding_values = dataset_ops._default_padding(input_dataset)
    self._padding_values = nest.map_structure_up_to(
        input_dataset.output_shapes,
        dataset_ops._padding_value_to_tensor,
        padding_values,
        input_dataset.output_types)
def __init__(self, dataset, output_types, output_shapes=None,
             output_classes=None, allow_unsafe_cast=False):
    """Creates a new dataset with the given output types and shapes.

    Unless `allow_unsafe_cast` is set, the given `dataset` must have a
    structure that is convertible:

    * The legacy output types of `dataset` must be the same as `output_types`
      modulo nesting.
    * Each legacy output shape of `dataset` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types. If
        omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g. to
        switch a sparse tensor represented as `tf.variant` to its user-visible
        type and shape. The compatibility checks are skipped in that case.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not
        compatible with the structure of `dataset`.
    """
    self._input_dataset = dataset
    checked = not allow_unsafe_cast

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if checked:
        # The types must agree once nesting is ignored.
        output_types = nest.map_structure(dtypes.as_dtype, output_types)
        if nest.flatten(input_types) != nest.flatten(output_types):
            raise ValueError(
                "Dataset with output types %r cannot be restructured to have "
                "output types %r" % (input_types, output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is None:
        # Re-nest the original shapes to follow the new type structure.
        output_shapes = nest.pack_sequence_as(
            output_types, nest.flatten(input_shapes))
    else:
        if checked:
            # Each requested shape must be compatible with the one it
            # replaces.
            nest.assert_same_structure(output_types, output_shapes)
            existing_shapes = nest.flatten(input_shapes)
            requested_shapes = nest.flatten_up_to(output_types, output_shapes)
            for existing, requested in zip(existing_shapes, requested_shapes):
                if not existing.is_compatible_with(requested):
                    raise ValueError(
                        "Dataset with output shapes %r cannot be restructured to have "
                        "incompatible output shapes %r"
                        % (input_shapes, output_shapes))
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
        # Re-nest the original classes to follow the new type structure.
        output_classes = nest.pack_sequence_as(
            output_types, nest.flatten(input_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details.

    Traces `scan_func` repeatedly until the state shapes stop changing, then
    builds the scan dataset op from the last trace.

    Args:
      input_dataset: The dataset whose elements are scanned.
      initial_state: The initial state; normalized with
        `structure.normalize_element`.
      scan_func: A function called with `(state, element)` that must return a
        pair `(new_state, output)`.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the classes or
        types of the new state do not match those of the initial state.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc`.
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by running `tf_scan_func`
    # one or more times below.
    self._state_structure = structure.type_spec_from_value(
        self._initial_state)

    # Iteratively rerun the scan function until reaching a fixed point on
    # the state shapes held in `self._state_structure`.
    need_to_rerun = True
    while need_to_rerun:
        wrapped_func = dataset_ops.StructuredFunctionWrapper(
            scan_func,
            self._transformation_name(),
            input_structure=(self._state_structure,
                             input_dataset._element_structure),  # pylint: disable=protected-access
            add_to_graph=False)
        if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
                and len(wrapped_func.output_types) == 2):
            raise TypeError(
                "The scan function must return a pair comprising the "
                "new state and the output value.")

        # Extract and validate class information from the returned values.
        new_state_classes, output_classes = wrapped_func.output_classes
        # This attribute was set by the original code; it is kept in case
        # something outside this method reads it.
        self._output_classes = output_classes
        old_state_classes = nest.map_structure(
            lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
            self._state_structure)
        for new_state_class, old_state_class in zip(
                nest.flatten(new_state_classes),
                nest.flatten(old_state_classes)):
            if not issubclass(new_state_class, old_state_class):
                raise TypeError(
                    "The element classes for the new state must match the initial "
                    "state. Expected %s; got %s." %
                    (old_state_classes, new_state_classes))

        # Extract and validate type information from the returned values.
        new_state_types, output_types = wrapped_func.output_types
        old_state_types = nest.map_structure(
            lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
            self._state_structure)
        for new_state_type, old_state_type in zip(
                nest.flatten(new_state_types), nest.flatten(old_state_types)):
            if new_state_type != old_state_type:
                raise TypeError(
                    "The element types for the new state must match the initial "
                    "state. Expected %s; got %s." %
                    (old_state_types, new_state_types))

        # Extract shape information from the returned values.
        new_state_shapes, output_shapes = wrapped_func.output_shapes
        old_state_shapes = nest.map_structure(
            lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
            self._state_structure)
        self._structure = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)

        flat_state_shapes = nest.flatten(old_state_shapes)
        flat_new_state_shapes = nest.flatten(new_state_shapes)
        weakened_state_shapes = [
            original.most_specific_compatible_shape(new)
            for original, new in zip(flat_state_shapes, flat_new_state_shapes)
        ]

        # Another trace is needed if any state shape had to be weakened to
        # accommodate what the scan function returned.
        need_to_rerun = False
        for original_shape, weakened_shape in zip(flat_state_shapes,
                                                  weakened_state_shapes):
            if original_shape.ndims is not None and (
                    weakened_shape.ndims is None or
                    original_shape.as_list() != weakened_shape.as_list()):
                need_to_rerun = True
                break

        if need_to_rerun:
            # TODO(b/110122868): Support a "most specific compatible structure"
            # method for combining structures, to avoid using legacy structures
            # in this method.
            self._state_structure = structure.convert_legacy_structure(
                old_state_types,
                nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
                old_state_classes)

    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())
    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
def _batch(self, batch_size):
    """Returns a `NestedStructure` with every component batched.

    Args:
      batch_size: Passed unchanged to the `_batch` method of each component of
        `self._nested_structure`.

    Returns:
      A new `NestedStructure` with the same nesting as this one.
    """

    def _batch_component(component):
        return component._batch(batch_size)  # pylint: disable=protected-access

    batched = nest.map_structure(_batch_component, self._nested_structure)
    return NestedStructure(batched)
def from_string_handle(string_handle, output_types, output_shapes=None,
                       output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    Use this to define a "feedable" iterator, where the concrete iterator to
    read from is chosen by feeding a value in a `tf.Session.run` call. In that
    case `string_handle` would be a `tf.placeholder`, fed in each step with the
    value of `tf.data.Iterator.string_handle`.

    For example, with one iterator tracking the position in a training dataset
    and another tracking the position in a test dataset, you can pick which
    one to use in each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
        output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)

    output_structure = structure_lib.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

    # pylint: disable=protected-access
    def _create_v2_resource():
        return gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle,
            output_types=output_structure._flat_types,
            output_shapes=output_structure._flat_shapes)

    if not compat.forward_compatible(2018, 8, 3):
        iterator_resource = gen_dataset_ops.iterator_from_string_handle(
            string_handle,
            output_types=output_structure._flat_types,
            output_shapes=output_structure._flat_shapes)
    elif _device_stack_is_empty():
        # With no device on the stack, the V2 op is created under "/cpu:0".
        with ops.device("/cpu:0"):
            iterator_resource = _create_v2_resource()
    else:
        iterator_resource = _create_v2_resource()
    # pylint: enable=protected-access

    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
def _batch(self, batch_size):
    """Returns a `NestedStructure` with every component batched.

    Args:
      batch_size: Passed unchanged to the `_batch` method of each component of
        `self._nested_structure`.

    Returns:
      A new `NestedStructure` with the same nesting as this one.
    """

    def _batch_component(component):
        return component._batch(batch_size)  # pylint: disable=protected-access

    batched = nest.map_structure(_batch_component, self._nested_structure)
    return NestedStructure(batched)
def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` whose `element_spec` is compatible with this
        iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the
      given `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        `element_spec`.
    """
    with ops.name_scope(name, "make_initializer") as name:
        # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
        # to that creating a circular dependency.

        def _from_element_spec(convert):
            return nest.map_structure(convert, dataset.element_spec)

        # pylint: disable=protected-access
        dataset_output_types = _from_element_spec(
            lambda spec: spec._to_legacy_output_types())
        dataset_output_shapes = _from_element_spec(
            lambda spec: spec._to_legacy_output_shapes())
        dataset_output_classes = _from_element_spec(
            lambda spec: spec._to_legacy_output_classes())
        # pylint: enable=protected-access

        nest.assert_same_structure(self.output_types, dataset_output_types)
        nest.assert_same_structure(self.output_shapes, dataset_output_shapes)

        # Classes must be identical, component by component.
        class_pairs = zip(nest.flatten(self.output_classes),
                          nest.flatten(dataset_output_classes))
        for own_class, other_class in class_pairs:
            if own_class is not other_class:
                raise TypeError(
                    f"Expected output classes {self.output_classes!r} but got "
                    f"dataset with output classes {dataset_output_classes!r}.")

        # Dtypes must be equal, component by component.
        dtype_pairs = zip(nest.flatten(self.output_types),
                          nest.flatten(dataset_output_types))
        for own_dtype, other_dtype in dtype_pairs:
            if own_dtype != other_dtype:
                raise TypeError(
                    f"Expected output types {self.output_types!r} but got dataset "
                    f"with output types {dataset_output_types!r}.")

        # Shapes only need to be compatible, not equal.
        shape_pairs = zip(nest.flatten(self.output_shapes),
                          nest.flatten(dataset_output_shapes))
        for own_shape, other_shape in shape_pairs:
            if not own_shape.is_compatible_with(other_shape):
                raise TypeError(
                    f"Expected output shapes compatible with {self.output_shapes!r} "
                    f"but got dataset with output shapes {dataset_output_shapes!r}.")

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
        # pylint: disable=protected-access
        return gen_dataset_ops.make_iterator(
            dataset._variant_tensor, self._iterator_resource, name=name)
def from_generator(generator, output_types, output_shapes=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns an object
    that supports the `iter()` protocol (e.g. a generator function). The
    elements it produces must be compatible with the given `output_types` and
    (optional) `output_shapes` arguments.

    For example:

    ```python
    import itertools

    def gen():
      for i in itertools.count(1):
        yield (i, [1] * i)

    ds = Dataset.from_generator(
        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
    value = ds.make_one_shot_iterator().get_next()

    sess.run(value)  # (1, array([1]))
    sess.run(value)  # (2, array([1, 1]))
    ```

    Args:
      generator: A callable object that takes no arguments and returns an
        object that supports the `iter()` protocol.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.

    Returns:
      A `Dataset`.

    Raises:
      TypeError: If `generator` is not callable.
    """
    if not callable(generator):
        raise TypeError("`generator` must be callable.")

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)

    flattened_types = nest.flatten(output_types)
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = dataset_ops.Dataset._GeneratorState(generator)

    def get_iterator_id_map_fn(unused_dummy):
        """Creates a unique `iterator_id` for each pass over the dataset.

        The "iterator_id" disambiguates between multiple concurrently
        existing iterators.

        Args:
          unused_dummy: Ignored value.

        Returns:
          A `tf.int64` tensor whose value uniquely identifies an iterator in
          `generator_state`.
        """
        return script_ops.py_func(
            generator_state.get_next_id, [], dtypes.int64, stateful=True)

    def generator_map_fn(iterator_id_t):
        """Generates the next element from iterator with ID `iterator_id_t`.

        We map this function across an infinite repetition of the
        `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

        Args:
          iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
            the iterator in `generator_state` from which to generate an
            element.

        Returns:
          A nested structure of tensors representing an element from the
          iterator.
        """

        def generator_py_func(iterator_id):
            """A `py_func` that will be called to invoke the iterator."""
            try:
                values = next(generator_state.get_iterator(iterator_id))
            except StopIteration:
                generator_state.iterator_completed(iterator_id)
                raise StopIteration("Iteration finished.")

            # Use the same _convert function from the py_func() implementation
            # to convert the returned values to arrays early, so that we can
            # inspect their values.
            flat_values_py = nest.flatten_up_to(output_types, values)
            # pylint: disable=protected-access
            ret_arrays = [
                script_ops.FuncRegistry._convert(
                    value, dtype=expected.as_numpy_dtype)
                for value, expected in zip(flat_values_py, flattened_types)
            ]
            # pylint: enable=protected-access

            # Additional type and shape checking to ensure that the components
            # of the generated element match the `output_types` and
            # `output_shapes` arguments.
            expectations = zip(ret_arrays, flattened_types, flattened_shapes)
            for array, expected_dtype, expected_shape in expectations:
                if array.dtype != expected_dtype.as_numpy_dtype:
                    raise TypeError(
                        "`generator` yielded an element of type %s where an element "
                        "of type %s was expected."
                        % (array.dtype, expected_dtype.as_numpy_dtype))
                if not expected_shape.is_compatible_with(array.shape):
                    raise ValueError(
                        "`generator` yielded an element of shape %s where an element "
                        "of shape %s was expected."
                        % (array.shape, expected_shape))

            return ret_arrays

        flat_values = script_ops.py_func(
            generator_py_func, [iterator_id_t], flattened_types, stateful=True)

        # The `py_func()` op drops the inferred shapes, so we add them back in
        # here.
        if output_shapes is not None:
            for tensor, shape in zip(flat_values, flattened_shapes):
                tensor.set_shape(shape)

        return nest.pack_sequence_as(output_types, flat_values)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(iterator_id_t):
        # First, generate an infinite dataset containing the iterator ID
        # repeated forever.
        repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None)

        # The `generator_map_fn` gets the next element from the iterator with
        # the relevant ID, and raises StopIteration when that iterator
        # contains no more elements.
        return repeated_id.map(generator_map_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64 ID that
    # will be used to identify the appropriate Python state, which is
    # encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`. The element value itself (0) is ignored.
    id_dataset = Dataset.from_tensors(0).map(get_iterator_id_map_fn)

    # A dataset that contains all of the elements generated by a single
    # iterator created from `generator`, identified by the iterator ID
    # contained in `id_dataset`. Lifting the iteration into a flat_map here
    # enables multiple repetitions and/or nested versions of the returned
    # dataset to be created, because it forces the generation of a new ID for
    # each version.
    return id_dataset.flat_map(flat_map_fn)
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    The iterator built here can be reused with many different datasets: it
    is not bound to a particular dataset and has no `initializer`. To
    initialize it, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    For example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding
        to each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element of this
        dataset. If omitted, each component will have an unconstrained
        shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared
        under the given name across multiple sessions that share the same
        devices (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type`
        objects corresponding to each component of an element of this
        iterator. If omitted, each component is assumed to be of type
        `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types`
        are not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)

    if output_classes is None:
        # Default every component to a dense tensor.
        output_classes = nest.map_structure(
            lambda _: ops.Tensor, output_types)

    nest.assert_same_structure(output_types, output_shapes)

    resource_name = "" if shared_name is None else shared_name
    flat_dense_types = nest.flatten(
        sparse.as_dense_types(output_types, output_classes))
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(output_shapes, output_classes))
    iterator_resource = gen_dataset_ops.iterator(
        container="",
        shared_name=resource_name,
        output_types=flat_dense_types,
        output_shapes=flat_dense_shapes)
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
def output_shapes(self):
    """Returns the pair of shapes produced by this dataset.

    The second entry is `self._padded_shapes` with `self._as_batch_shape`
    applied to every component.
    """
    # First output is a variant representing the Queue
    queue_shape = tensor_shape.vector(None)
    batched_shapes = nest.map_structure(self._as_batch_shape,
                                        self._padded_shapes)
    return (queue_shape, batched_shapes)
def testBatch(self, element_structure, batch_size,
              expected_batched_structure):
    """Checks that batching every component spec gives the expected result."""

    def batch_component(component_spec):
        return component_spec._batch(batch_size)

    actual_structure = nest.map_structure(batch_component, element_structure)
    self.assertEqual(actual_structure, expected_batched_structure)
def testUnbatch(self, element_structure, expected_unbatched_structure):
    """Checks that unbatching every component spec gives the expected result."""

    def unbatch_component(component_spec):
        return component_spec._unbatch()

    actual_structure = nest.map_structure(unbatch_component,
                                          element_structure)
    self.assertEqual(actual_structure, expected_unbatched_structure)
def testMapStructure(self):
    """Exercises `nest.map_structure` on valid inputs and on error cases."""
    first = (((1, 2), 3), 4, (5, 6))
    second = (((7, 8), 9), 10, (11, 12))

    # Single-structure mapping keeps the nesting and transforms each leaf.
    incremented = nest.map_structure(lambda v: v + 1, first)
    nest.assert_same_structure(first, incremented)
    self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(incremented))

    # Two-structure mapping combines corresponding leaves.
    summed = nest.map_structure(lambda a, b: a + b, first, second)
    self.assertEqual((((8, 10), 12), 14, (16, 18)), summed)

    # Bare scalars are treated as single-leaf structures.
    self.assertEqual(3, nest.map_structure(lambda v: v - 1, 4))
    self.assertEqual(7, nest.map_structure(lambda a, b: a + b, 3, 4))

    def ignore_pair(unused_a, unused_b):
        return None

    def ignore_one(unused_a):
        return None

    with self.assertRaisesRegexp(TypeError, "callable"):
        nest.map_structure("bad", incremented)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(ignore_pair, 3, (3,))

    with self.assertRaisesRegexp(TypeError, "same sequence type"):
        nest.map_structure(ignore_pair, ((3, 4), 5), {"a": (3, 4), "b": 5})

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(ignore_pair, ((3, 4), 5), (3, (4, 5)))

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(ignore_pair, ((3, 4), 5), (3, (4, 5)),
                           check_types=False)

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
        nest.map_structure(ignore_one, first, foo="a")

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
        nest.map_structure(ignore_one, first, check_types=False, foo="a")
def __init__(self, pipeline, output_dtypes=None, output_shapes=None,
             fail_on_device_mismatch=True, *, input_datasets=None,
             batch_size=1, num_threads=4, device_id=0, exec_separated=False,
             prefetch_queue_depth=2, cpu_prefetch_queue_depth=2,
             gpu_prefetch_queue_depth=2, dtypes=None, shapes=None):
    """Sets up the dataset from a pipeline and its declared output signature.

    Args:
      pipeline: The pipeline object. It is kept on the instance and also
        serialized with `serialize_pipeline`.
      output_dtypes: A single `tf.DType` or a tuple of `tf.DType` values
        (validated with `self._check_dtypes`).
      output_shapes: (Optional.) Shapes matching `output_dtypes`. When
        omitted, every output gets an unknown shape.
      fail_on_device_mismatch: Stored on the instance.
      input_datasets: Forwarded to `self._setup_inputs`.
      batch_size: Stored on the instance.
      num_threads: Stored on the instance.
      device_id: Stored on the instance; `None` is replaced with
        `types.CPU_ONLY_DEVICE_ID`.
      exec_separated: Stored on the instance.
      prefetch_queue_depth: Stored on the instance.
      cpu_prefetch_queue_depth: Stored on the instance.
      gpu_prefetch_queue_depth: Stored on the instance.
      dtypes: Passed with `output_dtypes` to `self._handle_deprecation`
        (presumably a deprecated alias -- confirm against that helper).
      shapes: Passed with `output_shapes` to `self._handle_deprecation`
        (presumably a deprecated alias -- confirm against that helper).

    Raises:
      TypeError: If `output_dtypes` fails the `self._check_dtypes` check.
    """
    output_shapes = self._handle_deprecation(output_shapes, shapes, "shapes")
    output_dtypes = self._handle_deprecation(output_dtypes, dtypes, "dtypes")

    if not self._check_dtypes(output_dtypes, tf.DType):
        message = (
            "`output_dtypes` should be provided as single tf.DType value "
            "or a tuple of tf.DType values. Got value `{}` of type `{}`.")
        raise TypeError(message.format(output_dtypes, type(output_dtypes)))

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_dtypes, tensor_shape.as_shape, output_shapes)
    else:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_dtypes)

    # Normalize a single output into one-element tuples.
    if not isinstance(output_dtypes, tuple):
        output_dtypes = (output_dtypes,)
        output_shapes = (output_shapes,)

    output_classes = nest.map_structure(lambda _: ops.Tensor, output_dtypes)

    self._pipeline_instance = pipeline  # keep the live Pipeline object
    self._pipeline_serialized = serialize_pipeline(pipeline)
    self._batch_size = batch_size
    self._num_threads = num_threads
    self._device_id = (
        types.CPU_ONLY_DEVICE_ID if device_id is None else device_id)
    self._exec_separated = exec_separated
    self._prefetch_queue_depth = prefetch_queue_depth
    self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth
    self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth
    self._output_shapes = output_shapes
    self._output_dtypes = output_dtypes
    self._fail_on_device_mismatch = fail_on_device_mismatch

    self._setup_inputs(input_datasets)

    self._structure = structure.convert_legacy_structure(
        self._output_dtypes, self._output_shapes, output_classes)

    variant_tensor = self._as_variant_tensor()
    super(_DALIDatasetV2, self).__init__(variant_tensor)
def _unbatch(self):
    """Returns a `NestedStructure` built from each component's `_unbatch()`."""

    def unbatch_component(component):
        return component._unbatch()

    unbatched = nest.map_structure(unbatch_component, self._nested_structure)
    return NestedStructure(unbatched)
def _to_legacy_output_classes(self):
    """Maps every nested component to its `_to_legacy_output_classes()` value."""

    def legacy_classes_of(component):
        return component._to_legacy_output_classes()

    return nest.map_structure(legacy_classes_of, self._nested_structure)
def _unbatch(self):
    """Returns a `NestedStructure` built from each component's `_unbatch()`."""

    def unbatch_component(component):
        return component._unbatch()

    unbatched = nest.map_structure(unbatch_component, self._nested_structure)
    return NestedStructure(unbatched)
def _to_legacy_output_classes(self):
    """Maps every nested component to its `_to_legacy_output_classes()` value."""

    def legacy_classes_of(component):
        return component._to_legacy_output_classes()

    return nest.map_structure(legacy_classes_of, self._nested_structure)
def output_classes(self):
    """Returns `ops.Tensor` for every component of `self._output_types`."""

    def tensor_class(unused_dtype):
        return ops.Tensor

    return nest.map_structure(tensor_class, self._output_types)
def output_shapes(self):
    """Returns the input dataset's shapes with the leading dimension removed."""

    def drop_leading_dim(shape):
        return shape[1:]

    input_shapes = self._input_dataset.output_shapes
    return nest.map_structure(drop_leading_dim, input_shapes)
def output_shapes(self):
    """Returns the pair of shapes produced by this dataset.

    The second entry is `self._padded_shapes` with `self._as_batch_shape`
    applied to every component.
    """
    # First output is a variant representing the Queue
    queue_shape = tensor_shape.vector(None)
    batched_shapes = nest.map_structure(self._as_batch_shape,
                                        self._padded_shapes)
    return (queue_shape, batched_shapes)
def output_types(self):
    """Maps each component spec of `self._output_structure` to its legacy type."""

    def legacy_type_of(component_spec):
        return component_spec._to_legacy_output_types()  # pylint: disable=protected-access

    return nest.map_structure(legacy_type_of, self._output_structure)
def from_generator(generator, output_types, output_shapes=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns an
    object that supports the `iter()` protocol (e.g. a generator function).
    The elements it produces must be compatible with the given
    `output_types` and (optional) `output_shapes` arguments.

    For example:

    ```python
    import itertools

    def gen():
      for i in itertools.count(1):
        yield (i, [1] * i)

    ds = Dataset.from_generator(
        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
    value = ds.make_one_shot_iterator().get_next()

    sess.run(value)  # (1, array([1]))
    sess.run(value)  # (2, array([1, 1]))
    ```

    Args:
      generator: A callable object that takes no arguments and returns an
        object that supports the `iter()` protocol.
      output_types: A nested structure of `tf.DType` objects corresponding
        to each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.

    Returns:
      A `Dataset`.

    Raises:
      TypeError: If `generator` is not callable.
    """
    if not callable(generator):
        raise TypeError("`generator` must be callable.")

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)

    flat_types = nest.flatten(output_types)
    flat_shapes = nest.flatten(output_shapes)

    generator_state = dataset_ops.Dataset._GeneratorState(generator)

    def get_iterator_id_map_fn(unused_dummy):
        """Creates a unique `iterator_id` for each pass over the dataset.

        The "iterator_id" disambiguates between multiple concurrently
        existing iterators.

        Args:
          unused_dummy: Ignored value.

        Returns:
          A `tf.int64` tensor whose value uniquely identifies an iterator
          in `generator_state`.
        """
        return script_ops.py_func(
            generator_state.get_next_id, [], dtypes.int64, stateful=True)

    def generator_map_fn(iterator_id_t):
        """Generates the next element from iterator with ID `iterator_id_t`.

        We map this function across an infinite repetition of the
        `iterator_id_t`, and raise `StopIteration` to terminate the
        iteration.

        Args:
          iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
            the iterator in `generator_state` from which to generate an
            element.

        Returns:
          A nested structure of tensors representing an element from the
          iterator.
        """

        def generator_py_func(iterator_id):
            """A `py_func` that will be called to invoke the iterator."""
            try:
                element = next(generator_state.get_iterator(iterator_id))
            except StopIteration:
                generator_state.iterator_completed(iterator_id)
                raise StopIteration("Iteration finished.")

            # Use the same _convert function from the py_func()
            # implementation to convert the returned values to arrays
            # early, so that we can inspect their values.
            flat_element = nest.flatten_up_to(output_types, element)
            # pylint: disable=protected-access
            converted = [
                script_ops.FuncRegistry._convert(
                    component, dtype=component_type.as_numpy_dtype)
                for component, component_type in zip(flat_element, flat_types)
            ]
            # pylint: enable=protected-access

            # Additional type and shape checking to ensure that the
            # components of the generated element match the `output_types`
            # and `output_shapes` arguments.
            for array, wanted_type, wanted_shape in zip(
                    converted, flat_types, flat_shapes):
                wanted_numpy_type = wanted_type.as_numpy_dtype
                if array.dtype != wanted_numpy_type:
                    raise TypeError(
                        "`generator` yielded an element of type %s where an "
                        "element of type %s was expected." %
                        (array.dtype, wanted_type.as_numpy_dtype))
                if not wanted_shape.is_compatible_with(array.shape):
                    raise ValueError(
                        "`generator` yielded an element of shape %s where an "
                        "element of shape %s was expected." %
                        (array.shape, wanted_shape))

            return converted

        flat_values = script_ops.py_func(
            generator_py_func, [iterator_id_t], flat_types, stateful=True)

        # The `py_func()` op drops the inferred shapes, so we add them back
        # in here.
        if output_shapes is not None:
            for value_t, static_shape in zip(flat_values, flat_shapes):
                value_t.set_shape(static_shape)

        return nest.pack_sequence_as(output_types, flat_values)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(iterator_id_t):
        # First, generate an infinite dataset containing the iterator ID
        # repeated forever.
        repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None)

        # The `generator_map_fn` gets the next element from the iterator
        # with the relevant ID, and raises StopIteration when that iterator
        # contains no more elements.
        return repeated_id.map(generator_map_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64 ID that
    # will be used to identify the appropriate Python state, which is
    # encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn)

    # A dataset that contains all of the elements generated by a single
    # iterator created from `generator`, identified by the iterator ID
    # contained in `id_dataset`. Lifting the iteration into a flat_map here
    # enables multiple repetitions and/or nested versions of the returned
    # dataset to be created, because it forces the generation of a new ID
    # for each version.
    return id_dataset.flat_map(flat_map_fn)
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:

    * `dataset.output_types` must be the same as `output_types` modulo
      nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each
      shape in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to
    using `tf.Tensor.set_shape()` where domain-specific knowledge is
    available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types. If
        omitted, the class types will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not
        compatible with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # Validate that the types are compatible: the flattened types must
    # match exactly, only the nesting may differ.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    original_flat_types = nest.flatten(dataset.output_types)
    requested_flat_types = nest.flatten(output_types)
    if original_flat_types != requested_flat_types:
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have "
            "output types %r" % (dataset.output_types, output_types))
    self._output_types = output_types

    if output_shapes is not None:
        # Validate that the shapes are compatible.
        nest.assert_same_structure(output_types, output_shapes)
        original_flat_shapes = nest.flatten(dataset.output_shapes)
        requested_flat_shapes = nest.flatten_up_to(output_types,
                                                   output_shapes)
        if any(not original.is_compatible_with(requested)
               for original, requested in zip(original_flat_shapes,
                                              requested_flat_shapes)):
            raise ValueError(
                "Dataset with output shapes %r cannot be restructured to "
                "have incompatible output shapes %r" %
                (dataset.output_shapes, output_shapes))
        self._output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # Inherit shapes from the original `dataset`.
        self._output_shapes = nest.pack_sequence_as(
            output_types, nest.flatten(dataset.output_shapes))

    if output_classes is not None:
        self._output_classes = output_classes
    else:
        # Inherit class types from the original `dataset`.
        self._output_classes = nest.pack_sequence_as(
            output_types, nest.flatten(dataset.output_classes))
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:

    * `dataset.output_types` must be the same as `output_types` modulo
      nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each
      shape in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to
    using `tf.Tensor.set_shape()` where domain-specific knowledge is
    available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types. If
        omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g.
        to switch a sparse tensor represented as `tf.variant` to its
        user-visible type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not
        compatible with the structure of `dataset`.
    """
    self._input_dataset = dataset
    validate = not allow_unsafe_cast

    if validate:
        # Validate that the types are compatible.
        output_types = nest.map_structure(dtypes.as_dtype, output_types)
        original_flat_types = nest.flatten(dataset.output_types)
        requested_flat_types = nest.flatten(output_types)
        if original_flat_types != requested_flat_types:
            raise ValueError(
                "Dataset with output types %r cannot be restructured to "
                "have output types %r" %
                (dataset.output_types, output_types))

    if output_shapes is not None:
        if validate:
            # Validate that the shapes are compatible.
            nest.assert_same_structure(output_types, output_shapes)
            original_flat_shapes = nest.flatten(dataset.output_shapes)
            requested_flat_shapes = nest.flatten_up_to(output_types,
                                                       output_shapes)
            if any(not original.is_compatible_with(requested)
                   for original, requested in zip(original_flat_shapes,
                                                  requested_flat_shapes)):
                raise ValueError(
                    "Dataset with output shapes %r cannot be restructured "
                    "to have incompatible output shapes %r" %
                    (dataset.output_shapes, output_shapes))
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # Inherit shapes from the original `dataset`.
        output_shapes = nest.pack_sequence_as(
            output_types, nest.flatten(dataset.output_shapes))

    if output_classes is None:
        # Inherit class types from the original `dataset`.
        output_classes = nest.pack_sequence_as(
            output_types, nest.flatten(dataset.output_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    # The restructured dataset reuses the input dataset's variant tensor.
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    The iterator built here can be reused with many different datasets: it
    is not bound to a particular dataset and has no `initializer`. To
    initialize it, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    For example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A (nested) structure of `tf.DType` objects
        corresponding to each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this
        dataset. If omitted, each component will have an unconstrained
        shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared
        under the given name across multiple sessions that share the same
        devices (e.g. when using a remote server).
      output_classes: (Optional.) A (nested) structure of Python `type`
        objects corresponding to each component of an element of this
        iterator. If omitted, each component is assumed to be of type
        `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types`
        are not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)

    if output_classes is None:
        # Default every component to a dense tensor.
        output_classes = nest.map_structure(
            lambda _: ops.Tensor, output_types)

    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    resource_name = "" if shared_name is None else shared_name
    flat_types = structure.get_flat_tensor_types(output_structure)
    flat_shapes = structure.get_flat_tensor_shapes(output_structure)
    iterator_resource = gen_dataset_ops.iterator_v2(
        container="",
        shared_name=resource_name,
        output_types=flat_types,
        output_shapes=flat_shapes)
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can
    choose between concrete iterators by feeding a value in a
    `tf.Session.run` call. In that case, `string_handle` would be a
    `tf.compat.v1.placeholder`, and you would feed it with the value of
    `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position
    in a training dataset and a test dataset, you could choose which to use
    in each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A (nested) structure of `tf.DType` objects
        corresponding to each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this
        dataset. If omitted, each component will have an unconstrained
        shape.
      output_classes: (Optional.) A (nested) structure of Python `type`
        objects corresponding to each component of an element of this
        iterator. If omitted, each component is assumed to be of type
        `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)

    if output_classes is None:
        # Default every component to a dense tensor.
        output_classes = nest.map_structure(
            lambda _: ops.Tensor, output_types)

    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    handle_t = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    flat_types = structure.get_flat_tensor_types(output_structure)
    flat_shapes = structure.get_flat_tensor_shapes(output_structure)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        handle_t, output_types=flat_types, output_shapes=flat_shapes)
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
def output_shapes(self):
    """Returns the input dataset's shapes with the leading dimension removed."""

    def drop_leading_dim(shape):
        return shape[1:]

    input_shapes = self._input_dataset.output_shapes
    return nest.map_structure(drop_leading_dim, input_shapes)
def output_classes(self):
    """Returns `ops.Tensor` for every component of `self._output_types`."""

    def tensor_class(unused_dtype):
        return ops.Tensor

    return nest.map_structure(tensor_class, self._output_types)
def testMapStructure(self):
    """Exercises `nest.map_structure` on valid inputs and on error cases."""
    first = (((1, 2), 3), 4, (5, 6))
    second = (((7, 8), 9), 10, (11, 12))

    # Single-structure mapping keeps the nesting and transforms each leaf.
    incremented = nest.map_structure(lambda v: v + 1, first)
    nest.assert_same_structure(first, incremented)
    self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(incremented))

    # Two-structure mapping combines corresponding leaves.
    summed = nest.map_structure(lambda a, b: a + b, first, second)
    self.assertEqual((((8, 10), 12), 14, (16, 18)), summed)

    # Bare scalars are treated as single-leaf structures.
    self.assertEqual(3, nest.map_structure(lambda v: v - 1, 4))
    self.assertEqual(7, nest.map_structure(lambda a, b: a + b, 3, 4))

    def ignore_pair(unused_a, unused_b):
        return None

    def ignore_one(unused_a):
        return None

    with self.assertRaisesRegexp(TypeError, "callable"):
        nest.map_structure("bad", incremented)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(ignore_pair, 3, (3,))

    with self.assertRaisesRegexp(TypeError, "same sequence type"):
        nest.map_structure(ignore_pair, ((3, 4), 5), {"a": (3, 4), "b": 5})

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(ignore_pair, ((3, 4), 5), (3, (4, 5)))

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(ignore_pair, ((3, 4), 5), (3, (4, 5)),
                           check_types=False)

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
        nest.map_structure(ignore_one, first, foo="a")

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
        nest.map_structure(ignore_one, first, check_types=False, foo="a")
def output_shapes(self):
    """Returns `TensorShape([])` for every component of `self._output_types`."""

    def scalar_shape(unused_dtype):
        return tensor_shape.TensorShape([])

    return nest.map_structure(scalar_shape, self._output_types)
def output_shapes(self):
    """Returns `TensorShape([])` for every component of `self._output_types`."""

    def scalar_shape(unused_dtype):
        return tensor_shape.TensorShape([])

    return nest.map_structure(scalar_shape, self._output_types)
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(
        self._drop_remainder_t)
    # pylint: disable=protected-access
    if constant_drop_remainder:
        # Remainder is statically known to be dropped, so the batch
        # dimension can use the static batch size (when it is known).
        def batch_component(component_spec):
            return component_spec._batch(
                tensor_util.constant_value(self._batch_size_t))
    else:
        # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown
        # statically) or `False` (explicitly retaining the remainder).
        def batch_component(component_spec):
            return component_spec._batch(None)

    self._element_spec = nest.map_structure(
        batch_component, self._map_func.output_structure)
    # pylint: enable=protected-access

    # Both ops take the same arguments; only the op name depends on the
    # forward-compatibility window.
    if compat.forward_compatible(2019, 8, 3):
        map_and_batch_op = ged_ops.map_and_batch_dataset
    else:
        map_and_batch_op = ged_ops.experimental_map_and_batch_dataset

    variant_tensor = map_and_batch_op(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)