def _next_func(string_handle):
  """Fetches the next element of this shard through a string handle.

  Rebuilds the multi-device iterator resource from `string_handle`, then reads
  the next element for the enclosing `shard_num` / `incarnation_id`.

  Args:
    string_handle: The string handle of the multi-device iterator.

  Returns:
    The flat list of tensors for the next element of the shard.
  """
  # pylint: disable=protected-access
  flat_types = structure.get_flat_tensor_types(self._element_spec)
  flat_shapes = structure.get_flat_tensor_shapes(self._element_spec)
  iterator_resource = (
      gen_dataset_ops.multi_device_iterator_from_string_handle(
          string_handle=string_handle,
          output_types=flat_types,
          output_shapes=flat_shapes))
  return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
      multi_device_iterator=iterator_resource,
      shard_num=shard_num,
      incarnation_id=incarnation_id,
      output_types=flat_types,
      output_shapes=flat_shapes)
def at(dataset, index):
  """Returns the element at a specific index in a dataset.

  Args:
    dataset: The `tf.data.Dataset` to fetch the element from. It must support
      random access.
    index: The index at which to fetch the element.

  Returns:
    A (nested) structure of values matching `tf.data.Dataset.element_spec`.

  Raises:
    UnimplementedError: If random access is not yet supported for a dataset.
      Currently, random access is supported for the following tf.data ops:
      `tf.data.Dataset.from_tensor_slices`, `tf.data.Dataset.shuffle`,
      `tf.data.Dataset.batch`, `tf.data.Dataset.shard`,
      `tf.data.Dataset.map`, `tf.data.Dataset.range`,
      `tf.data.Dataset.skip`, and `tf.data.Dataset.repeat`.
  """
  element_spec = dataset.element_spec
  # pylint: disable=protected-access
  return structure.from_tensor_list(
      element_spec,
      gen_experimental_dataset_ops.get_element_at_index(
          dataset._variant_tensor,
          index,
          output_types=structure.get_flat_tensor_types(element_spec),
          output_shapes=structure.get_flat_tensor_shapes(element_spec)))
def checkDatasetSpec(self, tf_value, expected_element_structure):
  """Checks the `DatasetSpec` of a dataset whose single element is `tf_value`.

  Verifies the spec type, the element structure, and the flat representation
  (one scalar `tf.variant`). It then checks that the dataset survives a
  `_to_tensor_list()` / `_from_tensor_list()` round trip.

  Args:
    tf_value: The value produced as the dataset's single element.
    expected_element_structure: The element structure the dataset should have.
  """
  dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
  dataset_structure = structure.type_spec_from_value(dataset)
  self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(dataset), expected_element_structure))
  self.assertEqual([dtypes.variant],
                   structure.get_flat_tensor_types(dataset_structure))
  self.assertEqual([tensor_shape.TensorShape([])],
                   structure.get_flat_tensor_shapes(dataset_structure))

  # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
  # and _to_tensor_list().
  round_trip_dataset = dataset_structure._from_tensor_list(
      dataset_structure._to_tensor_list(dataset))

  if isinstance(tf_value, dataset_ops.Dataset):
    # Flatten the *round-tripped* dataset so the round trip is actually
    # verified for nested datasets too. The previous code flattened the
    # original `dataset` and never inspected `round_trip_dataset`.
    self.assertDatasetsEqual(tf_value,
                             round_trip_dataset.flat_map(lambda x: x))
  elif isinstance(tf_value, optional_ops.Optional):
    self.assertDatasetProduces(
        round_trip_dataset.map(lambda opt: opt.get_value()),
        [self.evaluate(tf_value.get_value())],
        requires_initialization=True)
  else:
    self.assertDatasetProduces(
        round_trip_dataset, [self.evaluate(tf_value)],
        requires_initialization=True)
def _remote_next_func(string_handle):
  """Runs `next_func_concrete` on `self._source_device` via `remote_call`.

  Args:
    string_handle: The iterator string handle forwarded to the remote function.

  Returns:
    The outputs of the remote call, typed per the input dataset's structure.
  """
  call_args = [string_handle] + next_func_concrete.captured_inputs
  # pylint: disable=protected-access
  result_types = structure.get_flat_tensor_types(
      self._input_dataset._element_structure)
  return functional_ops.remote_call(
      target=self._source_device,
      args=call_args,
      Tout=result_types,
      f=next_func_concrete)
def _create_iterator(self, dataset):
  """Creates and initializes an anonymous iterator resource over `dataset`.

  Sets `self._dataset`, `self._element_spec`, the flat output types/shapes,
  `self._iterator_resource`, `self._deleter` and `self._resource_deleter`.

  Args:
    dataset: The `tf.data.Dataset` to iterate over.
  """
  # pylint: disable=protected-access
  dataset = dataset._apply_options()

  # Hold on to the dataset so it stays alive while this iterator is in use.
  # For instance, `tf.data.Dataset.from_generator` registers py_funcs needed
  # by `self._next_internal`; if the dataset were deleted, calling
  # `self.__next__(...)` would crash.
  self._dataset = dataset

  variant = dataset._variant_tensor
  spec = dataset.element_spec
  self._element_spec = spec
  self._flat_output_types = structure.get_flat_tensor_types(spec)
  self._flat_output_shapes = structure.get_flat_tensor_shapes(spec)

  with ops.colocate_with(variant):
    self._iterator_resource, self._deleter = (
        gen_dataset_ops.anonymous_iterator_v2(
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes))
    gen_dataset_ops.make_iterator(variant, self._iterator_resource)
    # Tie the resource's lifetime to this object: it is deleted along with it.
    self._resource_deleter = IteratorResourceDeleter(
        handle=self._iterator_resource,
        device=self._device,
        deleter=self._deleter)
def _create_iterator(self, dataset):
  """Creates and initializes an anonymous (V3) iterator resource.

  Applies the dataset's debug options and keeps a reference to the dataset.
  Records its element spec and flat output types/shapes, then creates the
  iterator resource colocated with the dataset variant and initializes it
  from the dataset.

  Args:
    dataset: The `tf.data.Dataset` to iterate over.
  """
  # pylint: disable=protected-access
  dataset = dataset._apply_debug_options()

  # Store dataset reference to ensure that dataset is alive when this iterator
  # is being used. For example, `tf.data.Dataset.from_generator` registers
  # a few py_funcs that are needed in `self._next_internal`. If the dataset
  # is deleted, this iterator crashes on `self.__next__(...)` call.
  self._dataset = dataset

  ds_variant = dataset._variant_tensor
  self._element_spec = dataset.element_spec
  self._flat_output_types = structure.get_flat_tensor_types(
      self._element_spec)
  self._flat_output_shapes = structure.get_flat_tensor_shapes(
      self._element_spec)
  with ops.colocate_with(ds_variant):
    self._iterator_resource = (gen_dataset_ops.anonymous_iterator_v3(
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes))
    if not context.executing_eagerly():
      # Add full type information to the graph so host memory types inside
      # variants stay on CPU, e.g., ragged string tensors.
      # TODO(b/224776031) Remove this when AnonymousIteratorV3 can use
      # (reverse) type inference and all other ops that are needed to
      # provide type information to the AnonymousIteratorV3 also support
      # type inference (esp. cross-function type inference) instead of
      # setting the full type information manually.
      fulltype = type_utils.iterator_full_type_from_spec(
          self._element_spec)
      # fulltype is PRODUCT[ITERATOR[PRODUCT[...]]]; sanity-check that the
      # innermost PRODUCT has one argument per flat output component.
      assert len(fulltype.args[0].args[0].args) == len(
          self._flat_output_types)
      self._iterator_resource.op.experimental_set_type(fulltype)
    # Initialize the iterator from the dataset (after the type annotation).
    gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
def _dummy_tensor_fn(value_structure):
  """Builds zero-filled placeholder tensors that mirror `value_structure`.

  Args:
    value_structure: The (nested) structure whose flat tensor shapes and
      dtypes describe the tensors to fake (presumably `TypeSpec`s; verify
      against callers).

  Returns:
    A structure packed like `value_structure` whose leaves are zero tensors.
  """

  def create_dummy_tensor(feature_shape, feature_type):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    # We would prefer to zero only the batch dimension, but under
    # DistributionStrategy the batch dimension is unknown, so we guess:
    # unknown dimensions become 0, and for an already-static shape the first
    # dimension is assumed to be the batch dimension and set to 0.
    dims = [
        tensor_shape.Dimension(0) if dim.value is None else dim
        for dim in feature_shape.dims
    ]
    if feature_shape.is_fully_defined() and dims:
      dims[0] = tensor_shape.Dimension(0)
    return array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)

  flat_shapes = structure.get_flat_tensor_shapes(value_structure)
  flat_types = structure.get_flat_tensor_types(value_structure)
  dummies = [
      create_dummy_tensor(shape, dtype)
      for shape, dtype in zip(flat_shapes, flat_types)
  ]
  return nest.pack_sequence_as(value_structure, dummies)
def __init__(self, filenames):
  """Creates a `SequenceFileDataset`.

  A `SequenceFileDataset` reads the (key, value) pairs stored sequentially in
  a Hadoop sequence file. Only the `org.apache.hadoop.io.Text` serialization
  type is currently supported, and compressed files are not.

  For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq")
  # Prints the (key, value) pairs inside a hadoop sequence file.
  for key, value in dataset:
    print(key, value)
  ```

  Args:
    filenames: A `tf.string` tensor containing one or more filenames.
  """
  self._filenames = ops.convert_to_tensor(
      filenames, dtype=dtypes.string, name="filenames")
  flat_types = structure.get_flat_tensor_types(self._element_structure)
  variant_tensor = gen_dataset_ops.sequence_file_dataset(
      self._filenames, flat_types)
  super(SequenceFileDataset, self).__init__(variant_tensor)
def trace_legacy_function(defun_kwargs):
  """Traces `wrapper_helper` as a legacy `function.Defun`.

  Args:
    defun_kwargs: Extra keyword arguments forwarded to `function.Defun`.

  Returns:
    A zero-argument callable that returns the `Defun`-decorated function.
  """
  input_types = structure.get_flat_tensor_types(self._input_structure)

  @function.Defun(*input_types, **defun_kwargs)
  def wrapped_fn(*args):
    ret = wrapper_helper(*args)
    return structure.to_tensor_list(self._output_structure, ret)

  return lambda: wrapped_fn
def get_next_as_optional(self):
  """Returns an `Optional` wrapping this iterator's next element.

  Returns:
    An `optional_ops._OptionalImpl` built with this iterator's `element_spec`.
  """
  spec = self.element_spec
  # pylint: disable=protected-access
  optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
      self._iterator_resource,
      output_types=structure.get_flat_tensor_types(spec),
      output_shapes=structure.get_flat_tensor_shapes(spec))
  return optional_ops._OptionalImpl(optional_variant, spec)
def get_next_as_optional(self):
  """Returns an `Optional` wrapping this iterator's next element.

  The get-next op is colocated with the iterator resource.

  Returns:
    An `optional_ops._OptionalImpl` built with this iterator's `element_spec`.
  """
  # TODO(b/169442955): Investigate the need for this colocation constraint.
  with ops.colocate_with(self._iterator_resource):
    spec = self.element_spec
    # pylint: disable=protected-access
    optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
        self._iterator_resource,
        output_types=structure.get_flat_tensor_types(spec),
        output_shapes=structure.get_flat_tensor_shapes(spec))
    return optional_ops._OptionalImpl(optional_variant, spec)
def _remote_next_func(string_handle):
  """Calls `next_func_concrete` on `source_device` and annotates the outputs.

  Each returned tensor's producing op gets the full type derived from
  `self._element_spec`.

  Args:
    string_handle: The iterator string handle forwarded to the remote function.

  Returns:
    The list of tensors returned by the remote call.
  """
  outputs = functional_ops.remote_call(
      target=source_device,
      args=[string_handle] + next_func_concrete.captured_inputs,
      Tout=structure.get_flat_tensor_types(self._element_spec),
      f=next_func_concrete)
  fulltype = structure.full_type_from_spec(self._element_spec)
  for output in outputs:
    output.op.experimental_set_type(fulltype)
  return outputs
def get_next(self, name=None):
  """Returns a nested structure of `tf.Tensor`s representing the next element.

  In graph mode, call this method *once* and feed its result into another
  computation. A typical loop then calls `tf.Session.run` on that computation
  and stops when the `Iterator.get_next()` operation raises
  `tf.errors.OutOfRangeError`.

  A skeleton of a training loop built this way:

  ```python
  dataset = ...  # A `tf.data.Dataset` object.
  iterator = dataset.make_initializable_iterator()
  next_element = iterator.get_next()

  # Build a TensorFlow graph that does something with each element.
  loss = model_function(next_element)
  optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
  train_op = optimizer.minimize(loss)

  with tf.compat.v1.Session() as sess:
    try:
      while True:
        sess.run(train_op)
    except tf.errors.OutOfRangeError:
      pass
  ```

  NOTE: Calling `Iterator.get_next()` several times is legitimate, e.g. to
  hand different elements to multiple devices in one step. A common pitfall,
  however, is calling it once per training-loop iteration. Each call adds ops
  to the graph, and running them allocates resources (including threads), so
  doing it every iteration causes slowdown and eventually exhausts resources.
  A warning is therefore logged once the number of calls exceeds
  `GET_NEXT_CALL_WARNING_THRESHOLD`.

  Args:
    name: (Optional.) A name for the created operation.

  Returns:
    A nested structure of `tf.Tensor` objects.
  """
  self._get_next_call_count += 1
  too_many_calls = (
      self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD)
  if too_many_calls:
    warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

  element_structure = self._structure
  # pylint: disable=protected-access
  flat_ret = gen_dataset_ops.iterator_get_next(
      self._iterator_resource,
      output_types=structure.get_flat_tensor_types(element_structure),
      output_shapes=structure.get_flat_tensor_shapes(element_structure),
      name=name)
  return structure.from_tensor_list(element_structure, flat_ret)
def __init__(self,
             dataset=None,
             components=None,
             element_spec=None,
             job_token=None):
  """Creates a new iterator from the given dataset.

  If `dataset` is not specified, the iterator is built from the given tensor
  `components` and `element_spec` instead. That path is used when the
  iterator is reconstructed from its `CompositeTensor` representation.

  Args:
    dataset: A `tf.data.Dataset` object.
    components: Tensor components to construct the iterator from.
    element_spec: A nested structure of `TypeSpec` objects that represents the
      type specification of elements of the iterator.
    job_token: A token to use for reading from a tf.data service job. Data
      will be partitioned among all iterators using the same token. If `None`,
      the iterator will not read from the tf.data service.

  Raises:
    ValueError: If `dataset` is not provided and either `components` or
      `element_spec` is missing, or if `dataset` is provided together with
      `components` or `element_spec`.
  """
  error_message = ("Either `dataset` or both `components` and "
                   "`element_spec` need to be provided.")

  self._device = context.context().device_name
  self._job_token = job_token

  if dataset is not None:
    if components is not None or element_spec is not None:
      raise ValueError(error_message)
    # Pin to CPU unless already inside an explicit CPU device scope.
    pin_to_cpu = (_device_stack_is_empty() or
                  context.context().device_spec.device_type != "CPU")
    if pin_to_cpu:
      with ops.device("/cpu:0"):
        self._create_iterator(dataset)
    else:
      self._create_iterator(dataset)
    return

  if components is None or element_spec is None:
    raise ValueError(error_message)
  # pylint: disable=protected-access
  self._element_spec = element_spec
  self._flat_output_types = structure.get_flat_tensor_types(element_spec)
  self._flat_output_shapes = structure.get_flat_tensor_shapes(element_spec)
  self._iterator_resource, self._deleter = components
def at(dataset, index):
  """Returns the element at a specific index in a dataset.

  Currently, random access is supported for the following tf.data operations:

     - `tf.data.Dataset.from_tensor_slices`,
     - `tf.data.Dataset.from_tensors`,
     - `tf.data.Dataset.shuffle`,
     - `tf.data.Dataset.batch`,
     - `tf.data.Dataset.shard`,
     - `tf.data.Dataset.map`,
     - `tf.data.Dataset.range`,
     - `tf.data.Dataset.zip`,
     - `tf.data.Dataset.skip`,
     - `tf.data.Dataset.repeat`,
     - `tf.data.Dataset.list_files`,
     - `tf.data.Dataset.SSTableDataset`,
     - `tf.data.Dataset.concatenate`,
     - `tf.data.Dataset.enumerate`,
     - `tf.data.Dataset.parallel_map`,
     - `tf.data.Dataset.prefetch`,
     - `tf.data.Dataset.take`,
     - `tf.data.Dataset.cache` (in-memory only)

  The cache operation can be used to enable random access for any dataset,
  even one built from transformations that are not on this list. For
  example, to get the element at index 3 of a TFDS dataset:

    ```python
    ds = tfds.load("mnist", split="train").cache()
    elem = tf.data.experimental.at(ds, 3)
    ```

  Args:
    dataset: The `tf.data.Dataset` to fetch the element from. It must support
      random access.
    index: The index at which to fetch the element.

  Returns:
    A (nested) structure of values matching `tf.data.Dataset.element_spec`.

  Raises:
    UnimplementedError: If random access is not yet supported for a dataset.
  """
  element_spec = dataset.element_spec
  # pylint: disable=protected-access
  return structure.from_tensor_list(
      element_spec,
      gen_experimental_dataset_ops.get_element_at_index(
          dataset._variant_tensor,
          index,
          output_types=structure.get_flat_tensor_types(element_spec),
          output_shapes=structure.get_flat_tensor_shapes(element_spec)))
def get_value(self, name=None):
  """Returns the value held by this optional, restructured per its spec.

  Args:
    name: (Optional.) A name for the created operation.

  Returns:
    A (nested) structure built from `self._element_spec`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  spec = self._element_spec
  with ops.name_scope(name, "OptionalGetValue",
                      [self._variant_tensor]) as scope:
    with ops.colocate_with(self._variant_tensor):
      flat_values = gen_dataset_ops.optional_get_value(
          self._variant_tensor,
          name=scope,
          output_types=structure.get_flat_tensor_types(spec),
          output_shapes=structure.get_flat_tensor_shapes(spec))
    # The composite-tensor deserialization below is deliberately left
    # uncolocated, since not all ops are guaranteed to have non-GPU kernels.
    return structure.from_tensor_list(spec, flat_values)
def __init__(self, dataset, path, shard_func, compression):
  """Wraps `dataset` in a `save_dataset_v2` op that writes it to `path`.

  Args:
    dataset: The input `tf.data.Dataset`.
    path: The destination path.
    shard_func: The sharding function (normalized by
      `_set_save_dataset_attributes`).
    compression: The compression setting passed to the op.
  """
  dataset, shard_func, use_shard_func, path = _set_save_dataset_attributes(
      dataset, shard_func, path)
  spec = dataset.element_spec
  variant_tensor = gen_experimental_dataset_ops.save_dataset_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      path=path,
      shard_func_other_args=shard_func.captured_inputs,
      shard_func=shard_func,
      use_shard_func=use_shard_func,
      compression=compression,
      output_types=structure.get_flat_tensor_types(spec),
      output_shapes=structure.get_flat_tensor_shapes(spec),
  )
  super(_SaveDataset, self).__init__(dataset, variant_tensor)
def get_value(self, name=None):
  """Returns the value held by this optional, restructured per its structure.

  Args:
    name: (Optional.) A name for the created operation.

  Returns:
    A (nested) structure built from `self._value_structure`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  value_structure = self._value_structure
  with ops.name_scope(name, "OptionalGetValue",
                      [self._variant_tensor]) as scope:
    flat_values = gen_dataset_ops.optional_get_value(
        self._variant_tensor,
        name=scope,
        output_types=structure.get_flat_tensor_types(value_structure),
        output_shapes=structure.get_flat_tensor_shapes(value_structure))
    return structure.from_tensor_list(value_structure, flat_values)
def testFlatStructure(self, value_fn, expected_structure, expected_types,
                      expected_shapes):
  """Checks the spec class and flat types/shapes derived from `value_fn()`.

  Args:
    value_fn: Zero-argument callable producing the value under test.
    expected_structure: The expected `TypeSpec` class.
    expected_types: The expected list of flat dtypes.
    expected_shapes: One entry per flat component: `None` means the shape
      must have unknown rank, otherwise a list compared with `as_list()`.
  """
  value = value_fn()
  s = structure.type_spec_from_value(value)
  self.assertIsInstance(s, expected_structure)
  flat_types = structure.get_flat_tensor_types(s)
  self.assertEqual(expected_types, flat_types)
  flat_shapes = structure.get_flat_tensor_shapes(s)
  self.assertLen(flat_shapes, len(expected_shapes))
  for expected, actual in zip(expected_shapes, flat_shapes):
    if expected is None:
      # Unknown rank: compare against `None` by identity.
      self.assertIsNone(actual.ndims)
    else:
      self.assertEqual(actual.as_list(), expected)
def __init__(self, dataset=None, components=None, element_structure=None):
  """Creates a new iterator from the given dataset.

  If `dataset` is not specified, the iterator is created from the given
  tensor components and element structure. That alternative is used when the
  iterator is reconstructed from its `CompositeTensor` representation.

  Args:
    dataset: A `tf.data.Dataset` object.
    components: Tensor components to construct the iterator from.
    element_structure: A nested structure of `TypeSpec` objects that
      represents the type specification of elements of the iterator.

  Raises:
    ValueError: If `dataset` is not provided and either `components` or
      `element_structure` is missing, or if `dataset` is provided together
      with `components` or `element_structure`.
  """
  # Parenthesized so both adjacent literals always form a single message,
  # even when the statement spans several physical lines.
  error_message = ("Either `dataset` or both `components` and "
                   "`element_structure` need to be provided.")

  self._device = context.context().device_name

  if dataset is None:
    if (components is None or element_structure is None):
      raise ValueError(error_message)
    # pylint: disable=protected-access
    self._structure = element_structure
    self._flat_output_types = structure.get_flat_tensor_types(
        self._structure)
    self._flat_output_shapes = structure.get_flat_tensor_shapes(
        self._structure)
    self._iterator_resource, self._deleter = components
    # Delete the resource when this object is deleted
    self._resource_deleter = IteratorResourceDeleter(
        handle=self._iterator_resource,
        device=self._device,
        deleter=self._deleter)
  else:
    if (components is not None or element_structure is not None):
      raise ValueError(error_message)
    if (_device_stack_is_empty() or
        context.context().device_spec.device_type != "CPU"):
      with ops.device("/cpu:0"):
        self._create_iterator(dataset)
    else:
      self._create_iterator(dataset)
def _remote_next_func(string_handle):
  """Calls `next_func_concrete` on `source_device` and annotates its outputs.

  Args:
    string_handle: The iterator string handle forwarded to the remote function.

  Returns:
    The list of tensors returned by the remote call.
  """
  outputs = functional_ops.remote_call(
      target=source_device,
      args=[string_handle] + next_func_concrete.captured_inputs,
      Tout=structure.get_flat_tensor_types(self._element_spec),
      f=next_func_concrete)
  # Full type information lets the RemoteCall op tell, per output, whether it
  # is a variant-backed value (e.g. a ragged tensor) holding strings or other
  # host-memory types. RemoteCall can then set AllocatorAttributes so those
  # copies stay on CPU.
  fulltype = structure.full_type_from_spec(self._element_spec)
  for output in outputs:
    output.op.experimental_set_type(fulltype)
  return outputs
def __init__(self, iterator_resource, initializer, output_types,
             output_shapes, output_classes):
  """Creates a new iterator from the given iterator resource.

  Note: Most users will not call this initializer directly, and will instead
  use `Dataset.make_initializable_iterator()` or
  `Dataset.make_one_shot_iterator()`.

  Args:
    iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
      iterator.
    initializer: A `tf.Operation` that should be run to initialize this
      iterator.
    output_types: A (nested) structure of `tf.DType` objects corresponding to
      each component of an element of this iterator.
    output_shapes: A (nested) structure of `tf.TensorShape` objects
      corresponding to each component of an element of this iterator.
    output_classes: A (nested) structure of Python `type` objects
      corresponding to each component of an element of this iterator.

  Raises:
    ValueError: If `output_types`, `output_shapes`, or `output_classes` is
      `None`.
  """
  self._iterator_resource = iterator_resource
  self._initializer = initializer

  if (output_types is None or output_shapes is None
      or output_classes is None):
    raise ValueError(
        "All of `output_types`, `output_shapes`, and `output_classes` "
        "must be specified to create an iterator. Got "
        f"`output_types` = {output_types!r}, "
        f"`output_shapes` = {output_shapes!r}, "
        f"`output_classes` = {output_classes!r}.")
  self._element_spec = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
      self._element_spec)
  self._flat_tensor_types = structure.get_flat_tensor_types(
      self._element_spec)

  self._string_handle = gen_dataset_ops.iterator_to_string_handle(
      self._iterator_resource)
  self._get_next_call_count = 0
  ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
def uncompress(element, output_spec):
  """Decompresses a dataset element produced by `compress`.

  Args:
    element: A scalar variant tensor created by calling `compress`.
    output_spec: A nested structure of `tf.TypeSpec` describing the type(s)
      of the decompressed element.

  Returns:
    The decompressed element, structured according to `output_spec`.
  """
  tensor_list = ged_ops.uncompress_element(
      element,
      output_types=structure.get_flat_tensor_types(output_spec),
      output_shapes=structure.get_flat_tensor_shapes(output_spec))
  return structure.from_tensor_list(output_spec, tensor_list)
def __init__(self, dataset=None, components=None, element_spec=None):
  """Creates a new iterator from the given dataset.

  Without a `dataset`, the iterator is rebuilt from tensor `components` and
  an `element_spec`. That path is used when reconstructing the iterator from
  its `CompositeTensor` representation.

  Args:
    dataset: A `tf.data.Dataset` object.
    components: Tensor components to construct the iterator from.
    element_spec: A (nested) structure of `TypeSpec` objects that represents
      the type specification of elements of the iterator.

  Raises:
    ValueError: If `dataset` is not provided and either `components` or
      `element_spec` is missing, or if `dataset` is provided together with
      `components` or `element_spec`.
  """
  super(OwnedIterator, self).__init__()

  if dataset is not None:
    if components is not None or element_spec is not None:
      raise ValueError(
          "When `dataset` is provided, `element_spec` and `components` must "
          "not be specified.")
    self._create_iterator(dataset)
  else:
    if components is None or element_spec is None:
      raise ValueError(
          "When `dataset` is not provided, both `components` and "
          "`element_spec` must be specified.")
    # pylint: disable=protected-access
    self._element_spec = element_spec
    self._flat_output_types = structure.get_flat_tensor_types(element_spec)
    self._flat_output_shapes = structure.get_flat_tensor_shapes(element_spec)
    # The V3 iterator carries only the resource; older versions also carry a
    # deleter.
    if use_anonymous_iterator_v3():
      self._iterator_resource, = components
    else:
      self._iterator_resource, self._deleter = components

  self._get_next_call_count = 0
def _create_iterator(self, dataset):
  """Creates and initializes an anonymous iterator resource over `dataset`.

  Sets `self._element_spec`, the flat output types/shapes,
  `self._iterator_resource`, `self._deleter` and `self._resource_deleter`.

  Args:
    dataset: The `tf.data.Dataset` to iterate over.
  """
  # pylint: disable=protected-access
  dataset = dataset._apply_options()
  variant = dataset._variant_tensor

  spec = dataset.element_spec
  self._element_spec = spec
  self._flat_output_types = structure.get_flat_tensor_types(spec)
  self._flat_output_shapes = structure.get_flat_tensor_shapes(spec)

  with ops.colocate_with(variant):
    self._iterator_resource, self._deleter = (
        gen_dataset_ops.anonymous_iterator_v2(
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes))
    gen_dataset_ops.make_iterator(variant, self._iterator_resource)
    # Tie the resource's lifetime to this object: it is deleted along with it.
    self._resource_deleter = IteratorResourceDeleter(
        handle=self._iterator_resource,
        device=self._device,
        deleter=self._deleter)
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the iterator's next element.

  Once `iterator` reaches the end of its sequence, the returned
  `tf.experimental.Optional` holds no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` that holds either the iterator's next
    element (when there is one) or no value.
  """
  spec = iterator.element_spec
  # pylint: disable=protected-access
  next_optional = gen_dataset_ops.iterator_get_next_as_optional(
      iterator._iterator_resource,
      output_types=structure.get_flat_tensor_types(spec),
      output_shapes=structure.get_flat_tensor_shapes(spec))
  return optional_ops._OptionalImpl(next_optional, spec)
def get_next_as_optional(iterator):
  """Returns an `Optional` holding the iterator's next value, if any.

  Once `iterator` reaches the end of its sequence, the returned `Optional`
  holds no value.

  Args:
    iterator: A `tf.compat.v1.data.Iterator` object.

  Returns:
    An `Optional` with the iterator's next value (when there is one) or no
    value.
  """
  spec = iterator.element_spec
  # pylint: disable=protected-access
  next_optional = gen_dataset_ops.iterator_get_next_as_optional(
      iterator._iterator_resource,
      output_types=structure.get_flat_tensor_types(spec),
      output_shapes=structure.get_flat_tensor_shapes(spec))
  return optional_ops._OptionalImpl(next_optional, spec)
def testOptionalStructure(self, tf_value_fn, expected_value_structure):
  """Checks the spec, flat representation and round trip of an `Optional`.

  Args:
    tf_value_fn: Zero-argument callable producing the value to wrap.
    expected_value_structure: The structure expected for the wrapped value.
  """
  tf_value = tf_value_fn()
  opt = optional_ops.Optional.from_value(tf_value)

  self.assertTrue(
      structure.are_compatible(opt.value_structure, expected_value_structure))

  opt_structure = structure.type_spec_from_value(opt)
  self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
  self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
  self.assertTrue(
      structure.are_compatible(opt_structure._value_structure,
                               expected_value_structure))
  self.assertEqual([dtypes.variant],
                   structure.get_flat_tensor_types(opt_structure))
  # `tensor_shape.scalar()` is deprecated; `TensorShape([])` is the
  # equivalent scalar shape (and is what other spec tests use).
  self.assertEqual([tensor_shape.TensorShape([])],
                   structure.get_flat_tensor_shapes(opt_structure))

  # All OptionalStructure objects are not compatible with a non-optional
  # value.
  non_optional_structure = structure.type_spec_from_value(
      constant_op.constant(42.0))
  self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

  # Assert that the optional survives a round-trip via _from_tensor_list()
  # and _to_tensor_list().
  round_trip_opt = opt_structure._from_tensor_list(
      opt_structure._to_tensor_list(opt))
  if isinstance(tf_value, optional_ops.Optional):
    self._assertElementValueEqual(
        self.evaluate(tf_value.get_value()),
        self.evaluate(round_trip_opt.get_value().get_value()))
  else:
    self._assertElementValueEqual(
        self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This method allows you to define a "feedable" iterator where you can choose
  between concrete iterators by feeding a value in a `tf.Session.run` call.
  In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and
  you would feed it with the value of `tf.data.Iterator.string_handle` in
  each step.

  For example, if you had two iterators that marked the current position in
  a training dataset and a test dataset, you could choose which to use in
  each step as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
      a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(output_types,
                                             tensor_shape.as_shape,
                                             output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  output_structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

  # The flat signature is the same on both device paths; compute it once.
  flat_types = structure.get_flat_tensor_types(output_structure)
  flat_shapes = structure.get_flat_tensor_shapes(output_structure)

  def _make_iterator_resource():
    """Creates the iterator resource op for `string_handle`."""
    return gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle, output_types=flat_types, output_shapes=flat_shapes)

  if _device_stack_is_empty():
    # No enclosing device scope: place the iterator resource on CPU.
    with ops.device("/cpu:0"):
      iterator_resource = _make_iterator_resource()
  else:
    iterator_resource = _make_iterator_resource()
  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
  """Creates a new, uninitialized `Iterator` with the given structure.

  This iterator-constructing method can be used to create an iterator that
  is reusable with many different datasets.

  The returned iterator is not bound to a particular dataset, and it has no
  `initializer`. To initialize the iterator, run the operation returned by
  `Iterator.make_initializer(dataset)`.

  The following is an example

  ```python
  iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

  dataset_range = Dataset.range(10)
  range_initializer = iterator.make_initializer(dataset_range)

  dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
  evens_initializer = iterator.make_initializer(dataset_evens)

  # Define a model based on the iterator; in this example, the model_fn
  # is expected to take scalar tf.int64 Tensors as input (see
  # the definition of 'iterator' above).
  prediction, loss = model_fn(iterator.get_next())

  # Train for `num_epochs`, where for each epoch, we first iterate over
  # dataset_range, and then iterate over dataset_evens.
  for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range`
    sess.run(range_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break

    # Initialize the iterator to `dataset_evens`
    sess.run(evens_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break
  ```

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    shared_name: (Optional.) If non-empty, this iterator will be shared under
      the given name across multiple sessions that share the same devices
      (e.g. when using a remote server).
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.

  Raises:
    TypeError: If the structures of `output_shapes` and `output_types` are
      not the same.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(output_types,
                                             tensor_shape.as_shape,
                                             output_shapes)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)
  output_structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  if shared_name is None:
    shared_name = ""

  # The flat signature is the same on both device paths; compute it once.
  flat_types = structure.get_flat_tensor_types(output_structure)
  flat_shapes = structure.get_flat_tensor_shapes(output_structure)

  def _make_iterator_resource():
    """Creates the (possibly shared) iterator resource op."""
    return gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=flat_types,
        output_shapes=flat_shapes)

  if _device_stack_is_empty():
    # No enclosing device scope: place the iterator resource on CPU.
    with ops.device("/cpu:0"):
      iterator_resource = _make_iterator_resource()
  else:
    iterator_resource = _make_iterator_resource()
  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)