def testSerializeDeserialize(self):
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected), types, shapes,
        sparse.get_classes(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
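# A public-API analogue of the round trip being tested above (my sketch using
# `tf.io`, not the private `sparse.*` helpers; note that the many-sparse
# deserializer prepends a batch dimension):
import tensorflow as tf

st = tf.sparse.SparseTensor(
    indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5])
serialized = tf.io.serialize_sparse(st)  # 1-D tensor of 3 serialized strings
batch = tf.expand_dims(serialized, 0)    # deserializer expects shape [N, 3]
restored = tf.io.deserialize_many_sparse(batch, dtype=tf.int32)
# `restored` has dense_shape [1, 4, 5]: the original tensor plus a batch dim.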
def make_initializer(self, dataset, name=None):
  """Returns a `tf.Operation` that initializes this iterator on `dataset`.

  Args:
    dataset: A `Dataset` whose structure is compatible with this iterator.
    name: (Optional.) A name for the created operation.

  Returns:
    A `tf.Operation` that can be run to initialize this iterator on the given
    `dataset`.

  Raises:
    TypeError: If `dataset` and this iterator do not have a compatible
      element structure.
  """
  with ops.name_scope(name, "make_initializer") as name:
    nest.assert_same_structure(self._output_types, dataset.output_types)
    nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
    for iterator_dtype, dataset_dtype in zip(
        nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
      if iterator_dtype != dataset_dtype:
        raise TypeError(
            "Expected output types %r but got dataset with output types %r."
            % (self._output_types, dataset.output_types))
    for iterator_shape, dataset_shape in zip(
        nest.flatten(self._output_shapes),
        nest.flatten(dataset.output_shapes)):
      if not iterator_shape.is_compatible_with(dataset_shape):
        raise TypeError("Expected output shapes compatible with %r but got "
                        "dataset with output shapes %r." %
                        (self._output_shapes, dataset.output_shapes))
    with ops.colocate_with(self._iterator_resource):
      return gen_dataset_ops.make_iterator(
          dataset._as_variant_tensor(),  # pylint: disable=protected-access
          self._iterator_resource,
          name=name)
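# A minimal usage sketch for `make_initializer` (assuming the TF1-style
# graph/session workflow via `tf.compat.v1`; the names below are
# illustrative, not taken from the code above):
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
iterator = tf.compat.v1.data.Iterator.from_structure(tf.int64,
                                                     tf.TensorShape([]))
dataset = tf.compat.v1.data.Dataset.range(10)
init_op = iterator.make_initializer(dataset)  # runs the structure checks
next_element = iterator.get_next()
with tf.compat.v1.Session() as sess:
  sess.run(init_op)
  print(sess.run(next_element))  # ==> 0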
def from_string_handle(string_handle, output_types, output_shapes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This method allows you to define a "feedable" iterator where you can choose
  between concrete iterators by feeding a value in a @{tf.Session.run} call.
  In that case, `string_handle` would be a @{tf.placeholder}, and you would
  feed it with the value of @{tf.data.Iterator.string_handle} in each step.

  For example, if you had two iterators that marked the current position in
  a training dataset and a test dataset, you could choose which to use in
  each step as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
      to a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
      objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
      component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape`
      objects corresponding to each component of an element of this dataset.
      If omitted, each component will have an unconstrained shape.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  nest.assert_same_structure(output_types, output_shapes)
  string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
  iterator_resource = gen_dataset_ops.iterator_from_string_handle(
      string_handle,
      output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)),
      output_shapes=nest.flatten(output_shapes))
  return Iterator(iterator_resource, None, output_types, output_shapes)
def __init__(self, dataset, output_types, output_shapes=None):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:
  * `dataset.output_types` must be the same as `output_types` modulo nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  super(_RestructuredDataset, self).__init__()
  self._dataset = dataset

  # Validate that the types are compatible.
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  flat_original_types = nest.flatten(dataset.output_types)
  flat_new_types = nest.flatten(output_types)
  if flat_original_types != flat_new_types:
    raise ValueError(
        "Dataset with output types %r cannot be restructured to have output "
        "types %r" % (dataset.output_types, output_types))

  self._output_types = output_types

  if output_shapes is None:
    # Inherit shapes from the original `dataset`.
    self._output_shapes = nest.pack_sequence_as(
        output_types, nest.flatten(dataset.output_shapes))
  else:
    # Validate that the shapes are compatible.
    nest.assert_same_structure(output_types, output_shapes)
    flat_original_shapes = nest.flatten(dataset.output_shapes)
    flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

    for original_shape, new_shape in zip(flat_original_shapes,
                                         flat_new_shapes):
      if not original_shape.is_compatible_with(new_shape):
        raise ValueError(
            "Dataset with output shapes %r cannot be restructured to have "
            "incompatible output shapes %r" % (dataset.output_shapes,
                                               output_shapes))
    self._output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
def __eq__(self, other):
  if not isinstance(other, NestedStructure):
    return False
  try:
    # pylint: disable=protected-access
    nest.assert_same_structure(self._nested_structure,
                               other._nested_structure)
  except (ValueError, TypeError):
    return False
  return nest.flatten(self._nested_structure) == nest.flatten(
      other._nested_structure)
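# A standalone sketch of the equality rule implemented above: two nested
# structures are equal iff they nest the same way *and* their flattened
# leaves compare equal. (This uses the public `tf.nest` API rather than the
# private class internals.)
import tensorflow as tf

def structures_equal(a, b):
  try:
    tf.nest.assert_same_structure(a, b)
  except (ValueError, TypeError):
    return False
  return tf.nest.flatten(a) == tf.nest.flatten(b)

assert structures_equal((1, {"x": 2}), (1, {"x": 2}))
assert not structures_equal((1, 2), (1, (2,)))  # same leaves, different nesting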
def __init__(self, variant_tensor, output_shapes, output_types,
             output_classes):
  # TODO(b/110122868): Consolidate the structure validation logic with the
  # similar logic in `Iterator.from_structure()` and
  # `Dataset.from_generator()`.
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  output_shapes = nest.map_structure_up_to(
      output_types, tensor_shape.as_shape, output_shapes)
  nest.assert_same_structure(output_types, output_shapes)
  nest.assert_same_structure(output_types, output_classes)
  self._variant_tensor = variant_tensor
  self._output_shapes = output_shapes
  self._output_types = output_types
  self._output_classes = output_classes
def is_compatible_with(self, other):
  if not isinstance(other, NestedStructure):
    return False
  try:
    # pylint: disable=protected-access
    nest.assert_same_structure(self._nested_structure,
                               other._nested_structure)
  except (ValueError, TypeError):
    return False

  return all(
      substructure.is_compatible_with(other_substructure)
      for substructure, other_substructure in zip(
          nest.flatten(self._nested_structure),
          nest.flatten(other._nested_structure)))
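# Compatibility relaxes equality: the nesting must still match exactly, but
# each leaf only needs to be *compatible* with its counterpart (for example,
# an unknown dimension matches a known one). A sketch with `tf.TensorSpec`:
import tensorflow as tf

a = {"x": tf.TensorSpec([None, 3], tf.float32)}
b = {"x": tf.TensorSpec([8, 3], tf.float32)}
tf.nest.assert_same_structure(a, b)  # same nesting
assert all(
    s.is_compatible_with(t)
    for s, t in zip(tf.nest.flatten(a),
                    tf.nest.flatten(b)))  # [None, 3] matches [8, 3]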
def _compareOutputToExpected(self, result_values, expected_values,
                             assert_items_equal):
  if assert_items_equal:
    # TODO(shivaniagrawal): add support for nested elements containing sparse
    # tensors when needed.
    self.assertItemsEqual(result_values, expected_values)
    return
  for i in range(len(result_values)):
    nest.assert_same_structure(result_values[i], expected_values[i])
    for result_value, expected_value in zip(
        nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
      if sparse_tensor.is_sparse(result_value):
        self.assertSparseValuesEqual(result_value, expected_value)
      else:
        self.assertAllEqual(result_value, expected_value)
def testMapStructure(self):
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = (((7, 8), 9), 10, (11, 12))
  structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
  nest.assert_same_structure(structure1, structure1_plus1)
  self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(structure1_plus1))
  structure1_plus_structure2 = nest.map_structure(
      lambda x, y: x + y, structure1, structure2)
  self.assertEqual(
      (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
      structure1_plus_structure2)

  self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

  self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

  with self.assertRaisesRegexp(TypeError, "callable"):
    nest.map_structure("bad", structure1_plus1)

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, 3, (3,))

  with self.assertRaisesRegexp(TypeError, "same sequence type"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                       check_types=False)

  with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    nest.map_structure(lambda x: None, structure1, foo="a")

  with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    nest.map_structure(lambda x: None, structure1, check_types=False,
                       foo="a")
def testSerializeDeserialize(self):
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected),
        sparse.get_sparse_types(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def testAssertSameStructure(self):
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  structure_different_num_elements = ("spam", "eggs")
  structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
  nest.assert_same_structure(structure1, structure2)
  nest.assert_same_structure("abc", 1.0)
  nest.assert_same_structure("abc", np.array([0, 1]))
  nest.assert_same_structure("abc", constant_op.constant([0, 1]))

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same number of elements"):
    nest.assert_same_structure(structure1, structure_different_num_elements)

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same number of elements"):
    nest.assert_same_structure((0, 1), np.array([0, 1]))

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same number of elements"):
    nest.assert_same_structure(0, (0, 1))

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same nested structure"):
    nest.assert_same_structure(structure1, structure_different_nesting)

  named_type_0 = collections.namedtuple("named_0", ("a", "b"))
  named_type_1 = collections.namedtuple("named_1", ("a", "b"))
  self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                    named_type_0("a", "b"))

  nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))

  self.assertRaises(TypeError, nest.assert_same_structure,
                    named_type_0(3, 4), named_type_1(3, 4))

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same nested structure"):
    nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same nested structure"):
    nest.assert_same_structure(((3,), 4), (3, (4,)))

  structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
  with self.assertRaisesRegexp(TypeError,
                               "don't have the same sequence type"):
    nest.assert_same_structure(structure1, structure1_list)
  nest.assert_same_structure(structure1, structure2, check_types=False)
  nest.assert_same_structure(structure1, structure1_list, check_types=False)
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
  """Creates a new, uninitialized `Iterator` with the given structure.

  This iterator-constructing method can be used to create an iterator that
  is reusable with many different datasets.

  The returned iterator is not bound to a particular dataset, and it has
  no `initializer`. To initialize the iterator, run the operation returned by
  `Iterator.make_initializer(dataset)`.

  The following is an example:

  ```python
  iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

  dataset_range = Dataset.range(10)
  range_initializer = iterator.make_initializer(dataset_range)

  dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
  evens_initializer = iterator.make_initializer(dataset_evens)

  # Define a model based on the iterator; in this example, the model_fn
  # is expected to take scalar tf.int64 Tensors as input (see
  # the definition of 'iterator' above).
  prediction, loss = model_fn(iterator.get_next())

  # Train for `num_epochs`, where for each epoch, we first iterate over
  # dataset_range, and then iterate over dataset_evens.
  for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range`
    sess.run(range_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break

    # Initialize the iterator to `dataset_evens`
    sess.run(evens_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break
  ```

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    shared_name: (Optional.) If non-empty, this iterator will be shared under
      the given name across multiple sessions that share the same devices
      (e.g. when using a remote server).
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.

  Raises:
    TypeError: If the structures of `output_shapes` and `output_types` are
      not the same.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  if shared_name is None:
    shared_name = ""
  iterator_resource = gen_dataset_ops.iterator(
      container="",
      shared_name=shared_name,
      output_types=nest.flatten(
          sparse.as_dense_types(output_types, output_classes)),
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(output_shapes, output_classes)))
  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This method allows you to define a "feedable" iterator where you can choose
  between concrete iterators by feeding a value in a `tf.Session.run` call.
  In that case, `string_handle` would be a `tf.placeholder`, and you would
  feed it with the value of `tf.data.Iterator.string_handle` in each step.

  For example, if you had two iterators that marked the current position in
  a training dataset and a test dataset, you could choose which to use in
  each step as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
      to a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  output_structure = structure_lib.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
  # pylint: disable=protected-access
  if compat.forward_compatible(2018, 8, 3):
    if _device_stack_is_empty():
      with ops.device("/cpu:0"):
        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle,
            output_types=output_structure._flat_types,
            output_shapes=output_structure._flat_shapes)
    else:
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          output_types=output_structure._flat_types,
          output_shapes=output_structure._flat_shapes)
  else:
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle,
        output_types=output_structure._flat_types,
        output_shapes=output_structure._flat_shapes)
  # pylint: enable=protected-access
  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:
  * `dataset.output_types` must be the same as `output_types` modulo nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.
    output_classes: (Optional.) A nested structure of class types. If omitted,
      the class types will be inherited from `dataset`.
    allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
      reported output types and shapes of the restructured dataset, e.g. to
      switch a sparse tensor represented as `tf.variant` to its user-visible
      type and shape.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  self._input_dataset = dataset

  input_types = dataset_ops.get_legacy_output_types(dataset)
  if not allow_unsafe_cast:
    # Validate that the types are compatible.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    flat_original_types = nest.flatten(input_types)
    flat_new_types = nest.flatten(output_types)
    if flat_original_types != flat_new_types:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have "
          "output types %r" %
          (dataset_ops.get_legacy_output_types(dataset), output_types))

  input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  if output_shapes is None:
    # Inherit shapes from the original `dataset`.
    output_shapes = nest.pack_sequence_as(output_types,
                                          nest.flatten(input_shapes))
  else:
    if not allow_unsafe_cast:
      # Validate that the shapes are compatible.
      nest.assert_same_structure(output_types, output_shapes)
      flat_original_shapes = nest.flatten(input_shapes)
      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

      for original_shape, new_shape in zip(flat_original_shapes,
                                           flat_new_shapes):
        if not original_shape.is_compatible_with(new_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (input_shapes, output_shapes))
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)

  input_classes = dataset_ops.get_legacy_output_classes(dataset)
  if output_classes is None:
    # Inherit class types from the original `dataset`.
    output_classes = nest.pack_sequence_as(output_types,
                                           nest.flatten(input_classes))

  self._structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
  super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:
  * `dataset.output_types` must be the same as `output_types` modulo nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.
    output_classes: (Optional.) A nested structure of class types. If omitted,
      the class types will be inherited from `dataset`.
    allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
      reported output types and shapes of the restructured dataset, e.g. to
      switch a sparse tensor represented as `tf.variant` to its user-visible
      type and shape.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  super(_RestructuredDataset, self).__init__()
  self._input_dataset = dataset

  if not allow_unsafe_cast:
    # Validate that the types are compatible.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    flat_original_types = nest.flatten(dataset.output_types)
    flat_new_types = nest.flatten(output_types)
    if flat_original_types != flat_new_types:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have "
          "output types %r" % (dataset.output_types, output_types))

  self._output_types = output_types

  if output_shapes is None:
    # Inherit shapes from the original `dataset`.
    self._output_shapes = nest.pack_sequence_as(
        output_types, nest.flatten(dataset.output_shapes))
  else:
    if not allow_unsafe_cast:
      # Validate that the shapes are compatible.
      nest.assert_same_structure(output_types, output_shapes)
      flat_original_shapes = nest.flatten(dataset.output_shapes)
      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

      for original_shape, new_shape in zip(flat_original_shapes,
                                           flat_new_shapes):
        if not original_shape.is_compatible_with(new_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (dataset.output_shapes,
                                                 output_shapes))
    self._output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)

  if output_classes is None:
    # Inherit class types from the original `dataset`.
    self._output_classes = nest.pack_sequence_as(
        output_types, nest.flatten(dataset.output_classes))
  else:
    self._output_classes = output_classes
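# The `_RestructuredDataset` variants above implement an "unsafe cast" on
# shapes. A rough public-API analogue (my sketch, not library code) is to
# reassert a stronger shape inside `Dataset.map`, which likewise tightens
# the reported element spec without changing the data:
import tensorflow as tf

dataset = tf.data.Dataset.from_generator(
    lambda: iter([[1.0, 2.0, 3.0]]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32))

def _assert_length_three(x):
  x.set_shape([3])  # domain knowledge: every element is a length-3 vector
  return x

dataset = dataset.map(_assert_length_three)
print(dataset.element_spec)  # TensorSpec(shape=(3,), dtype=tf.float32, ...)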
def make_initializer(self, dataset, name=None):
  """Returns a `tf.Operation` that initializes this iterator on `dataset`.

  Args:
    dataset: A `Dataset` whose `element_spec` is compatible with this
      iterator.
    name: (Optional.) A name for the created operation.

  Returns:
    A `tf.Operation` that can be run to initialize this iterator on the given
    `dataset`.

  Raises:
    TypeError: If `dataset` and this iterator do not have a compatible
      `element_spec`.
  """
  with ops.name_scope(name, "make_initializer") as name:
    # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` because
    # doing so would create a circular dependency.
    # pylint: disable=protected-access
    dataset_output_types = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),
        dataset.element_spec)
    dataset_output_shapes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),
        dataset.element_spec)
    dataset_output_classes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),
        dataset.element_spec)
    # pylint: enable=protected-access

    nest.assert_same_structure(self.output_types, dataset_output_types)
    nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
    for iterator_class, dataset_class in zip(
        nest.flatten(self.output_classes),
        nest.flatten(dataset_output_classes)):
      if iterator_class is not dataset_class:
        raise TypeError(
            f"Expected output classes {self.output_classes!r} but got "
            f"dataset with output classes {dataset_output_classes!r}.")
    for iterator_dtype, dataset_dtype in zip(
        nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
      if iterator_dtype != dataset_dtype:
        raise TypeError(
            f"Expected output types {self.output_types!r} but got dataset "
            f"with output types {dataset_output_types!r}.")
    for iterator_shape, dataset_shape in zip(
        nest.flatten(self.output_shapes),
        nest.flatten(dataset_output_shapes)):
      if not iterator_shape.is_compatible_with(dataset_shape):
        raise TypeError(
            f"Expected output shapes compatible with {self.output_shapes!r} "
            f"but got dataset with output shapes {dataset_output_shapes!r}.")

  # TODO(b/169442955): Investigate the need for this colocation constraint.
  with ops.colocate_with(self._iterator_resource):
    # pylint: disable=protected-access
    return gen_dataset_ops.make_iterator(
        dataset._variant_tensor, self._iterator_resource, name=name)
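# A short usage sketch for the `element_spec`-based validation above,
# reusing one iterator across two datasets with identical specs (assumes
# TF2 with v1 compatibility behavior enabled):
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
ds_a = tf.compat.v1.data.Dataset.range(5)
ds_b = ds_a.map(lambda x: x * 2)  # same element_spec as ds_a
iterator = tf.compat.v1.data.Iterator.from_structure(
    tf.compat.v1.data.get_output_types(ds_a),
    tf.compat.v1.data.get_output_shapes(ds_a))
init_a = iterator.make_initializer(ds_a)  # passes the spec checks
init_b = iterator.make_initializer(ds_b)  # also passes: specs are identical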
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
  """Creates a new, uninitialized `Iterator` with the given structure.

  This iterator-constructing method can be used to create an iterator that
  is reusable with many different datasets.

  The returned iterator is not bound to a particular dataset, and it has
  no `initializer`. To initialize the iterator, run the operation returned by
  `Iterator.make_initializer(dataset)`.

  The following is an example:

  ```python
  iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

  dataset_range = Dataset.range(10)
  range_initializer = iterator.make_initializer(dataset_range)

  dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
  evens_initializer = iterator.make_initializer(dataset_evens)

  # Define a model based on the iterator; in this example, the model_fn
  # is expected to take scalar tf.int64 Tensors as input (see
  # the definition of 'iterator' above).
  prediction, loss = model_fn(iterator.get_next())

  # Train for `num_epochs`, where for each epoch, we first iterate over
  # dataset_range, and then iterate over dataset_evens.
  for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range`
    sess.run(range_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break

    # Initialize the iterator to `dataset_evens`
    sess.run(evens_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break
  ```

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    shared_name: (Optional.) If non-empty, this iterator will be shared under
      the given name across multiple sessions that share the same devices
      (e.g. when using a remote server).
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.

  Raises:
    TypeError: If the structures of `output_shapes` and `output_types` are
      not the same.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  if shared_name is None:
    shared_name = ""
  if compat.forward_compatible(2018, 8, 3):
    if _device_stack_is_empty():
      with ops.device("/cpu:0"):
        iterator_resource = gen_dataset_ops.iterator_v2(
            container="",
            shared_name=shared_name,
            output_types=nest.flatten(
                sparse.as_dense_types(output_types, output_classes)),
            output_shapes=nest.flatten(
                sparse.as_dense_shapes(output_shapes, output_classes)))
    else:
      iterator_resource = gen_dataset_ops.iterator_v2(
          container="",
          shared_name=shared_name,
          output_types=nest.flatten(
              sparse.as_dense_types(output_types, output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(output_shapes, output_classes)))
  else:
    iterator_resource = gen_dataset_ops.iterator(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))
  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This method allows you to define a "feedable" iterator where you can choose
  between concrete iterators by feeding a value in a `tf.Session.run` call.
  In that case, `string_handle` would be a `tf.placeholder`, and you would
  feed it with the value of `tf.data.Iterator.string_handle` in each step.

  For example, if you had two iterators that marked the current position in
  a training dataset and a test dataset, you could choose which to use in
  each step as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
      to a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
  if compat.forward_compatible(2018, 8, 3):
    if _device_stack_is_empty():
      with ops.device("/cpu:0"):
        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle,
            output_types=nest.flatten(
                sparse.as_dense_types(output_types, output_classes)),
            output_shapes=nest.flatten(
                sparse.as_dense_shapes(output_shapes, output_classes)))
    else:
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          output_types=nest.flatten(
              sparse.as_dense_types(output_types, output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(output_shapes, output_classes)))
  else:
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))
  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)