def get_next(self, name=None):
  """Returns a nested structure of `tf.Tensor`s containing the next element.

  Args:
    name: (Optional.) A name for the created operation.

  Returns:
    A nested structure of `tf.Tensor` objects.
  """
  self._get_next_call_count += 1
  if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
    warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
  return sparse.deserialize_sparse_tensors(
      nest.pack_sequence_as(
          self._output_types,
          gen_dataset_ops.iterator_get_next(
              self._iterator_resource,
              output_types=nest.flatten(
                  sparse.as_dense_types(self._output_types,
                                        self._output_classes)),
              output_shapes=nest.flatten(
                  sparse.as_dense_shapes(self._output_shapes,
                                         self._output_classes)),
              name=name)),
      self._output_types, self._output_shapes, self._output_classes)
def testIndefiniteRepeatShapeInference(self):
  dataset = self.make_batch_feature(
      filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
  for shape, clazz in zip(nest.flatten(dataset.output_shapes),
                          nest.flatten(dataset.output_classes)):
    if issubclass(clazz, ops.Tensor):
      self.assertEqual(32, shape[0])
def _as_variant_tensor(self):
  return gen_dataset_ops.ignore_errors_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)))
def _apply_fn(dataset):
  """Function from `Dataset` to `Dataset` that applies the transformation."""
  tensor_batch_size = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")

  flattened = _RestructuredDataset(
      dataset,
      tuple(nest.flatten(dataset.output_types)),
      output_classes=tuple(nest.flatten(dataset.output_classes)))

  def _predicate(*xs):
    """Return `True` if this element is a full batch."""
    # Extract the dynamic batch size from the first component of the
    # flattened batched element.
    first_component = xs[0]
    first_component_batch_size = array_ops.shape(
        first_component, out_type=dtypes.int64)[0]
    return math_ops.equal(first_component_batch_size, tensor_batch_size)

  filtered = flattened.filter(_predicate)

  maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)

  def _set_first_dimension(shape):
    return shape.merge_with(
        tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))

  known_shapes = nest.map_structure(_set_first_dimension,
                                    dataset.output_shapes)
  return _RestructuredDataset(
      filtered,
      dataset.output_types,
      known_shapes,
      output_classes=dataset.output_classes)
def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode."""
  self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
      dataset_ops.get_structure(dataset2)))
  self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
      dataset_ops.get_structure(dataset1)))
  flattened_types = nest.flatten(
      dataset_ops.get_legacy_output_types(dataset1))

  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)
  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())

    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    assert len(op1) == len(op2)
    for i in range(len(op1)):
      if sparse_tensor.is_sparse(op1[i]):
        self.assertSparseValuesEqual(op1[i], op2[i])
      elif flattened_types[i] == dtypes.string:
        self.assertAllEqual(op1[i], op2[i])
      else:
        self.assertAllClose(op1[i], op2[i])
def testFlattenAndPack(self):
  structure = ((3, 4), 5, (6, 7, (9, 10), 8))
  flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
  self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
  self.assertEqual(
      nest.pack_sequence_as(structure, flat),
      (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

  point = collections.namedtuple("Point", ["x", "y"])
  structure = (point(x=4, y=2), ((point(x=1, y=0),),))
  flat = [4, 2, 1, 0]
  self.assertEqual(nest.flatten(structure), flat)
  restructured_from_flat = nest.pack_sequence_as(structure, flat)
  self.assertEqual(restructured_from_flat, structure)
  self.assertEqual(restructured_from_flat[0].x, 4)
  self.assertEqual(restructured_from_flat[0].y, 2)
  self.assertEqual(restructured_from_flat[1][0][0].x, 1)
  self.assertEqual(restructured_from_flat[1][0][0].y, 0)

  self.assertEqual([5], nest.flatten(5))
  self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

  self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

  with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
    nest.pack_sequence_as("scalar", [4, 5])

  with self.assertRaisesRegexp(TypeError, "flat_sequence"):
    nest.pack_sequence_as([4, 5], "bad_sequence")

  with self.assertRaises(ValueError):
    nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
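# Hedged illustration of the flatten/pack round trip exercised by the test
# above, using the public `tf.nest` API (assumed to behave like the internal
# `nest` module for plain tuples): `flatten` yields leaves left-to-right, and
# `pack_sequence_as` rebuilds the original nesting around a flat list.
import tensorflow as tf

structure = ((3, 4), 5, (6, 7, (9, 10), 8))
print(tf.nest.flatten(structure))  # [3, 4, 5, 6, 7, 9, 10, 8]
print(tf.nest.pack_sequence_as(
    structure, ["a", "b", "c", "d", "e", "f", "g", "h"]))
# (('a', 'b'), 'c', ('d', 'e', ('f', 'g'), 'h'))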
def __init__(self, dataset):
  """Creates a new iterator over the given dataset.

  For example:
  ```python
  dataset = tf.contrib.data.Dataset.range(4)
  for x in Iterator(dataset):
    print(x)
  ```

  Args:
    dataset: A `tf.contrib.data.Dataset` object.

  Raises:
    RuntimeError: When invoked without eager execution enabled.
  """
  if not context.in_eager_mode():
    raise RuntimeError(
        "{} objects only make sense when eager execution is enabled".format(
            type(self)))
  ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
  self._output_types = dataset.output_types
  self._flat_output_types = nest.flatten(dataset.output_types)
  self._flat_output_shapes = nest.flatten(dataset.output_shapes)
  self._resource = gen_dataset_ops.iterator(
      container="",
      shared_name=_iterator_shared_name(),
      output_types=self._flat_output_types,
      output_shapes=self._flat_output_shapes)
  gen_dataset_ops.make_iterator(ds_variant, self._resource)
def materialize(self, shared_name=None, container=None):
  """Materialize creates a MaterializedIndexedDataset.

  IndexedDatasets can be combined through operations such as TBD. Therefore,
  they are only materialized when absolutely required.

  Args:
    shared_name: a string for the shared name to use for the resource.
    container: a string for the container to store the resource.

  Returns:
    A MaterializedIndexedDataset.
  """
  if container is None:
    container = ""
  if shared_name is None:
    shared_name = ""
  materialized_resource = (
      ged_ops.experimental_materialized_index_dataset_handle(
          container=container,
          shared_name=shared_name,
          output_types=nest.flatten(
              sparse.as_dense_types(self.output_types, self.output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(self.output_shapes,
                                     self.output_classes))))

  with ops.colocate_with(materialized_resource):
    materializer = ged_ops.experimental_indexed_dataset_materialize(
        self._as_variant_tensor(), materialized_resource)
  return MaterializedIndexedDataset(materialized_resource, materializer,
                                    self.output_classes, self.output_types,
                                    self.output_shapes)
def from_value(value):
  """Returns an `Optional` that wraps the given value.

  Args:
    value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

  Returns:
    An `Optional` that wraps `value`.
  """
  # TODO(b/110122868): Consolidate this destructuring logic with the
  # similar code in `Dataset.from_tensors()`.
  with ops.name_scope("optional") as scope:
    with ops.name_scope("value"):
      value = nest.pack_sequence_as(value, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
              t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(value))
      ])

    encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
    output_classes = sparse.get_classes(value)
    output_shapes = nest.pack_sequence_as(
        value, [t.get_shape() for t in nest.flatten(value)])
    output_types = nest.pack_sequence_as(
        value, [t.dtype for t in nest.flatten(value)])

  return _OptionalImpl(
      gen_dataset_ops.optional_from_value(encoded_value, name=scope),
      output_shapes, output_types, output_classes)
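# Hedged sketch of the user-facing counterpart of `from_value` (assumed here
# to be exposed as `tf.experimental.Optional` in recent TF 2.x releases):
# wrap a nested value and read it back with the standard `has_value()` /
# `get_value()` accessors.
import tensorflow as tf

opt = tf.experimental.Optional.from_value((tf.constant(1),
                                           tf.constant([2.0, 3.0])))
print(opt.has_value())   # tf.Tensor(True, shape=(), dtype=bool)
print(opt.get_value())   # (tf.Tensor 1, tf.Tensor [2. 3.])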
def assertShapesEqual(self, a, b):
  for a, b in zip(nest.flatten(a), nest.flatten(b)):
    self.assertEqual(a.ndims, b.ndims)
    if a.ndims is None:
      continue
    for c, d in zip(a.as_list(), b.as_list()):
      self.assertEqual(c, d)
def testRoundTripConversion(self, value_fn):
  value = value_fn()
  s = structure.Structure.from_value(value)

  def maybe_stack_ta(v):
    if isinstance(v, tensor_array_ops.TensorArray):
      return v.stack()
    else:
      return v

  before = self.evaluate(maybe_stack_ta(value))
  after = self.evaluate(
      maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))

  flat_before = nest.flatten(before)
  flat_after = nest.flatten(after)
  for b, a in zip(flat_before, flat_after):
    if isinstance(b, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(b.indices, a.indices)
      self.assertAllEqual(b.values, a.values)
      self.assertAllEqual(b.dense_shape, a.dense_shape)
    elif isinstance(
        b,
        (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
      self.assertRaggedEqual(b, a)
    else:
      self.assertAllEqual(b, a)
def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode."""
  self.assertEqual(dataset1.output_types, dataset2.output_types)
  self.assertEqual(dataset1.output_classes, dataset2.output_classes)

  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)
  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())

    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    assert len(op1) == len(op2)
    for i in range(len(op1)):
      if isinstance(
          op1[i],
          (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
        self.assertSparseValuesEqual(op1[i], op2[i])
      else:
        self.assertAllEqual(op1[i], op2[i])
def testToBatchedTensorList(self, value_fn, element_0_fn):
  batched_value = value_fn()
  s = structure.Structure.from_value(batched_value)
  batched_tensor_list = s._to_batched_tensor_list(batched_value)

  # The batch dimension is 2 for all of the test cases.
  # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
  # tensors in which we store sparse tensors.
  for t in batched_tensor_list:
    if t.dtype != dtypes.variant:
      self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

  # Test that the 0th element from the unbatched tensor is equal to the
  # expected value.
  expected_element_0 = self.evaluate(element_0_fn())
  unbatched_s = s._unbatch()
  actual_element_0 = unbatched_s._from_tensor_list(
      [t[0] for t in batched_tensor_list])

  for expected, actual in zip(
      nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
    if sparse_tensor.is_sparse(expected):
      self.assertSparseValuesEqual(expected, actual)
    elif ragged_tensor.is_ragged(expected):
      self.assertRaggedEqual(expected, actual)
    else:
      self.assertAllEqual(expected, actual)
def get_next_as_optional(iterator):
  """Returns an `Optional` that contains the next value from the iterator.

  If `iterator` has reached the end of the sequence, the returned `Optional`
  will have no value.

  Args:
    iterator: A `tf.data.Iterator` object.

  Returns:
    An `Optional` object representing the next value from the iterator (if it
    has one) or no value.
  """
  # pylint: disable=protected-access
  return optional_ops._OptionalImpl(
      gen_dataset_ops.iterator_get_next_as_optional(
          iterator._iterator_resource,
          output_types=nest.flatten(
              sparse.as_dense_types(iterator.output_types,
                                    iterator.output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(iterator.output_shapes,
                                     iterator.output_classes))),
      structure.Structure._from_legacy_structure(iterator.output_types,
                                                 iterator.output_shapes,
                                                 iterator.output_classes))
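# Hedged usage sketch: consume a dataset through `Optional`s so the end of
# the sequence becomes a boolean check rather than an OutOfRangeError. This
# uses the public TF 2.x iterator method `get_next_as_optional()`, assumed
# to mirror the helper above.
import tensorflow as tf

it = iter(tf.data.Dataset.range(2))
while True:
  opt = it.get_next_as_optional()
  if not bool(opt.has_value()):
    break
  print(opt.get_value().numpy())  # prints 0, then 1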
def get_next(self, name=None):
  """See `tf.data.Iterator.get_next`."""
  self._get_next_call_count += 1
  if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
    warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

  flat_result = []
  # TODO(priyag): This will fail if the input size (typically number of
  # batches) is not divisible by number of devices.
  # How do we handle that more gracefully / let the user know?
  for buffer_resource in self._buffering_resources:
    flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
        buffer_resource,
        output_types=data_nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        name=name)

    ret = sparse.deserialize_sparse_tensors(
        data_nest.pack_sequence_as(self.output_types, flat_ret),
        self.output_types, self.output_shapes, self.output_classes)

    for tensor, shape in zip(
        data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
      if isinstance(tensor, ops.Tensor):
        tensor.set_shape(shape)
    flat_result.append(ret)

  return nest.pack_sequence_as(self._devices, flat_result)
def make_initializer(self, dataset, name=None):
  """Returns a `tf.Operation` that initializes this iterator on `dataset`.

  Args:
    dataset: A `Dataset` whose structure is compatible with this iterator.
    name: (Optional.) A name for the created operation.

  Returns:
    A `tf.Operation` that can be run to initialize this iterator on the given
    `dataset`.

  Raises:
    TypeError: If `dataset` and this iterator do not have a compatible
      element structure.
  """
  with ops.name_scope(name, "make_initializer") as name:
    nest.assert_same_structure(self._output_types, dataset.output_types)
    nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
    for iterator_dtype, dataset_dtype in zip(
        nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
      if iterator_dtype != dataset_dtype:
        raise TypeError(
            "Expected output types %r but got dataset with output types %r."
            % (self._output_types, dataset.output_types))
    for iterator_shape, dataset_shape in zip(
        nest.flatten(self._output_shapes),
        nest.flatten(dataset.output_shapes)):
      if not iterator_shape.is_compatible_with(dataset_shape):
        raise TypeError(
            "Expected output shapes compatible with %r but got dataset with "
            "output shapes %r." % (self._output_shapes,
                                   dataset.output_shapes))
    with ops.colocate_with(self._iterator_resource):
      return gen_dataset_ops.make_iterator(
          dataset._as_variant_tensor(), self._iterator_resource, name=name)  # pylint: disable=protected-access
def testSerializeDeserialize(self):
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected), types, shapes,
        sparse.get_classes(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def _check_shape(*elements):
  flatten_tensors = nest.flatten(elements)
  flatten_shapes = nest.flatten(expected_shapes)
  checked_tensors = [
      with_shape(shape, tensor)
      for shape, tensor in zip(flatten_shapes, flatten_tensors)
  ]
  return nest.pack_sequence_as(elements, checked_tensors)
def _as_variant_tensor(self):
  return gen_dataset_ops.random_dataset(
      seed=self._seed,
      seed2=self._seed2,
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)))
def _as_variant_tensor(self):
  return gen_dataset_ops.set_stats_aggregator_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      self._stats_aggregator._resource,  # pylint: disable=protected-access
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)),
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
def _as_variant_tensor(self):
  return self._op_function(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      self._tag,
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)),
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
def output_shapes(self):
  ret = self._data_inputs[0].output_shapes
  for data_input in self._data_inputs[1:]:
    ret = nest.pack_sequence_as(ret, [
        ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
            nest.flatten(ret), nest.flatten(data_input.output_shapes))
    ])
  return ret
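# Hedged example of the shape-merging rule used above:
# `TensorShape.most_specific_compatible_shape` keeps dimensions on which the
# two shapes agree and relaxes disagreeing ones to None.
import tensorflow as tf

a = tf.TensorShape([32, 10])
b = tf.TensorShape([64, 10])
print(a.most_specific_compatible_shape(b))  # (None, 10)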
def testFlattenDictOrder(self):
  """`flatten` orders dicts by key, including OrderedDicts."""
  ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
  plain = {"d": 3, "b": 1, "a": 0, "c": 2}
  ordered_flat = nest.flatten(ordered)
  plain_flat = nest.flatten(plain)
  self.assertEqual([0, 1, 2, 3], ordered_flat)
  self.assertEqual([0, 1, 2, 3], plain_flat)
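# Hedged illustration of the key-ordering behaviour checked above, via the
# public `tf.nest` API (assumed to sort dict entries by key in the same way):
# a plain dict and an OrderedDict with identical items flatten identically.
import collections
import tensorflow as tf

ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
assert tf.nest.flatten(ordered) == [0, 1, 2, 3]
assert tf.nest.flatten(plain) == [0, 1, 2, 3]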
def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Returns a `Structure` that represents the given legacy structure.

  This method provides a way to convert from the existing `Dataset` and
  `Iterator` structure-related properties to a `Structure` object. A "legacy"
  structure is represented by the `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects corresponding
      to each component of a structured value.

  Returns:
    A `Structure`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one
      of the component classes in `output_classes` is not supported.
  """
  flat_types = nest.flatten(output_types)
  flat_shapes = nest.flatten(output_shapes)
  flat_classes = nest.flatten(output_classes)
  flat_ret = []
  for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
                                               flat_classes):
    if isinstance(flat_class, Structure):
      flat_ret.append(flat_class)
    elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
      flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
    elif issubclass(flat_class, ops.Tensor):
      flat_ret.append(TensorStructure(flat_type, flat_shape))
    elif issubclass(flat_class, tensor_array_ops.TensorArray):
      # We sneaked the dynamic_size and infer_shape into the legacy shape.
      flat_ret.append(
          TensorArrayStructure(
              flat_type, flat_shape[2:],
              dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
              infer_shape=tensor_shape.dimension_value(flat_shape[1])))
    else:
      # NOTE(mrry): Since legacy structures produced by iterators only
      # comprise Tensors, SparseTensors, and nests, we do not need to
      # support all structure types here.
      raise TypeError(
          "Could not build a structure for output class %r" % (flat_class,))

  ret = nest.pack_sequence_as(output_classes, flat_ret)
  if isinstance(ret, Structure):
    return ret
  else:
    return NestedStructure(ret)
def from_string_handle(string_handle, output_types, output_shapes=None): """Creates a new, uninitialized `Iterator` based on the given handle. This method allows you to define a "feedable" iterator where you can choose between concrete iterators by feeding a value in a @{tf.Session.run} call. In that case, `string_handle` would a @{tf.placeholder}, and you would feed it with the value of @{tf.data.Iterator.string_handle} in each step. For example, if you had two iterators that marked the current position in a training dataset and a test dataset, you could choose which to use in each step as follows: ```python train_iterator = tf.data.Dataset(...).make_one_shot_iterator() train_iterator_handle = sess.run(train_iterator.string_handle()) test_iterator = tf.data.Dataset(...).make_one_shot_iterator() test_iterator_handle = sess.run(test_iterator.string_handle()) handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle( handle, train_iterator.output_types) next_element = iterator.get_next() loss = f(next_element) train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle}) test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle}) ``` Args: string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to a handle produced by the `Iterator.string_handle()` method. output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`) objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`) component of an element of this dataset. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element of this dataset. If omitted, each component will have an unconstrainted shape. Returns: An `Iterator`. """ output_types = nest.map_structure(dtypes.as_dtype, output_types) if output_shapes is None: output_shapes = nest.map_structure( lambda _: tensor_shape.TensorShape(None), output_types) else: output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) nest.assert_same_structure(output_types, output_shapes) string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) iterator_resource = gen_dataset_ops.iterator_from_string_handle( string_handle, output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)), output_shapes=nest.flatten(output_shapes)) return Iterator(iterator_resource, None, output_types, output_shapes)
def _as_variant_tensor(self):
  # pylint: disable=protected-access
  return gen_dataset_ops.directed_interleave_dataset(
      self._selector_input._as_variant_tensor(),
      [data_input._as_variant_tensor() for data_input in self._data_inputs],
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)))
def _as_variant_tensor(self):
  input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  return gen_dataset_ops.scan_dataset(
      input_t,
      nest.flatten(self._initial_state),
      self._scan_func.captured_inputs,
      f=self._scan_func,
      output_types=nest.flatten(self.output_types),
      output_shapes=nest.flatten(self.output_shapes))
def _as_variant_tensor(self):
  return gen_dataset_ops.dense_to_sparse_batch_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      self._batch_size,
      row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape),  # pylint: disable=protected-access
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)))
def make_dataset_resource(self):
  return gen_dataset_ops.sloppy_interleave_dataset(
      self._input_dataset.make_dataset_resource(),
      self._map_func.captured_inputs,
      self._cycle_length,
      self._block_length,
      f=self._map_func,
      output_types=nest.flatten(self.output_types),
      output_shapes=nest.flatten(self.output_shapes))
def _as_variant_tensor(self):
  return gen_dataset_ops.slide_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      window_size=self._window_size,
      stride=self._stride,
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
      output_types=nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes)))
def _prefetch_fn(handle):
  """Prefetches one element from `input_iterator`."""
  remote_iterator = iterator_ops.Iterator.from_string_handle(
      handle, self.output_types, self.output_shapes, self.output_classes)
  ret = remote_iterator.get_next()
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def as_dataset(self, sample_batch_size=None, num_steps=None, num_parallel_calls=None, single_deterministic_pass=False): """Creates and returns a dataset that returns entries from the buffer. A single entry from the dataset is equivalent to one output from `get_next(sample_batch_size=sample_batch_size, num_steps=num_steps)`. Args: sample_batch_size: (Optional.) An optional batch_size to specify the number of items to return. If None (default), a single item is returned which matches the data_spec of this class (without a batch dimension). Otherwise, a batch of sample_batch_size items is returned, where each tensor in items will have its first dimension equal to sample_batch_size and the rest of the dimensions match the corresponding data_spec. num_steps: (Optional.) Optional way to specify that sub-episodes are desired. If None (default), a batch of single items is returned. Otherwise, a batch of sub-episodes is returned, where a sub-episode is a sequence of consecutive items in the replay_buffer. The returned tensors will have first dimension equal to sample_batch_size (if sample_batch_size is not None), subsequent dimension equal to num_steps, and remaining dimensions which match the data_spec of this class. num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, representing the number elements to process in parallel. If not specified, elements will be processed sequentially. single_deterministic_pass: Python boolean. If `True`, the dataset will return a single deterministic pass through its underlying data. **NOTE**: If the buffer is modified while a Dataset iterator is iterating over this data, the iterator may miss any new data or otherwise have subtly invalid data. Returns: A dataset of type tf.data.Dataset, elements of which are 2-tuples of: - An item or sequence of items or batch thereof - Auxiliary info for the items (i.e. ids, probs). Raises: NotImplementedError: If a non-default argument value is not supported. ValueError: If the data spec contains lists that must be converted to tuples. """ # data_tf.nest.flatten does not flatten python lists, nest.flatten does. if tf.nest.flatten(self._data_spec) != data_nest.flatten( self._data_spec): raise ValueError( 'Cannot perform gather; data spec contains lists and this conflicts ' 'with gathering operator. Convert any lists to tuples. ' 'For example, if your spec looks like [a, b, c], ' 'change it to (a, b, c). Spec structure is:\n {}'.format( tf.nest.map_structure(lambda spec: spec.dtype, self._data_spec))) if single_deterministic_pass: ds = self._single_deterministic_pass_dataset( sample_batch_size=sample_batch_size, num_steps=num_steps, num_parallel_calls=num_parallel_calls) else: ds = self._as_dataset(sample_batch_size=sample_batch_size, num_steps=num_steps, num_parallel_calls=num_parallel_calls) if self._stateful_dataset: options = tf.data.Options() if hasattr(options, 'experimental_allow_stateful'): options.experimental_allow_stateful = True ds = ds.with_options(options) return ds
def output_shapes(self):
  input_shapes = self._input_dataset.output_shapes
  return nest.pack_sequence_as(input_shapes, [
      tensor_shape.vector(None).concatenate(s)
      for s in nest.flatten(self._input_dataset.output_shapes)
  ])
def _as_variant_tensor(self):
  return gen_dataset_ops.sequence_file_dataset(
      self._filenames, nest.flatten(self.output_types))
def testIndefiniteRepeatShapeInference(self):
  dataset = readers.make_tf_record_dataset(
      file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
  for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
    self.assertEqual(32, shape[0])
def normalize_element(element, element_signature=None): """Normalizes a nested structure of element components. * Components matching `SparseTensorSpec` are converted to `SparseTensor`. * Components matching `RaggedTensorSpec` are converted to `RaggedTensor`. * Components matching `DatasetSpec` or `TensorArraySpec` are passed through. * `CompositeTensor` components are passed through. * All other components are converted to `Tensor`. Args: element: A nested structure of individual components. element_signature: (Optional.) A nested structure of `tf.DType` objects corresponding to each component of `element`. If specified, it will be used to set the exact type of output tensor when converting input components which are not tensors themselves (e.g. numpy arrays, native python types, etc.) Returns: A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`, or `TensorArray` objects. """ normalized_components = [] if element_signature is None: components = nest.flatten(element) flattened_signature = [None] * len(components) pack_as = element else: flattened_signature = nest.flatten(element_signature) components = nest.flatten_up_to(element_signature, element) pack_as = element_signature with ops.name_scope("normalize_element"): # Imported here to avoid circular dependency. from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top for i, (t, spec) in enumerate(zip(components, flattened_signature)): try: if spec is None: spec = type_spec_from_value(t, use_fallback=False) except TypeError: # TypeError indicates it was not possible to compute a `TypeSpec` for # the value. As a fallback try converting the value to a tensor. normalized_components.append( ops.convert_to_tensor(t, name="component_%d" % i)) else: if isinstance(spec, sparse_tensor.SparseTensorSpec): normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) elif isinstance(spec, ragged_tensor.RaggedTensorSpec): normalized_components.append( ragged_tensor.convert_to_tensor_or_ragged_tensor( t, name="component_%d" % i)) elif isinstance( spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)): normalized_components.append(t) elif isinstance(spec, NoneTensorSpec): normalized_components.append(NoneTensor()) elif isinstance(t, composite_tensor.CompositeTensor): normalized_components.append(t) else: dtype = getattr(spec, "dtype", None) normalized_components.append( ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) return nest.pack_sequence_as(pack_as, normalized_components)
def _testVariadicInputs(self, dataset_fn, input_datasets):
  self.assertEqual(
      nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())
def from_value(value):
  flat_nested_structure = [
      Structure.from_value(sub_value) for sub_value in nest.flatten(value)
  ]
  return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping defun for reduce_func.""" # Iteratively rerun the reduce function until reaching a fixed point on # `self._state_structure`. self._state_structure = self._init_func.output_structure state_types = self._init_func.output_types state_shapes = self._init_func.output_shapes state_classes = self._init_func.output_classes need_to_rerun = True while need_to_rerun: wrapped_func = structured_function.StructuredFunctionWrapper( reduce_func, self._transformation_name(), input_structure=(self._state_structure, input_dataset.element_spec), add_to_graph=False) # Extract and validate class information from the returned values. for new_state_class, state_class in zip( nest.flatten(wrapped_func.output_classes), nest.flatten(state_classes)): if not issubclass(new_state_class, state_class): raise TypeError( f"Invalid `reducer`. The output class of the " f"`reducer.reduce_func` {wrapped_func.output_classes}, " f"does not match the class of the reduce state " f"{self._state_classes}.") # Extract and validate type information from the returned values. for new_state_type, state_type in zip( nest.flatten(wrapped_func.output_types), nest.flatten(state_types)): if new_state_type != state_type: raise TypeError( f"Invalid `reducer`. The element types for the new state " f"{wrapped_func.output_types} do not match the element types " f"of the old state {self._init_func.output_types}.") # Extract shape information from the returned values. flat_state_shapes = nest.flatten(state_shapes) flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) ] need_to_rerun = False for original_shape, weakened_shape in zip(flat_state_shapes, weakened_state_shapes): if original_shape.ndims is not None and ( weakened_shape.ndims is None or original_shape.as_list() != weakened_shape.as_list()): need_to_rerun = True break if need_to_rerun: state_shapes = nest.pack_sequence_as( self._init_func.output_shapes, weakened_state_shapes) self._state_structure = structure.convert_legacy_structure( state_types, state_shapes, state_classes) self._reduce_func = wrapped_func self._reduce_func.function.add_to_graph(ops.get_default_graph())
def testNestedStructure(self): components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]), np.array([6., 7.])), np.array([8, 9, 10], dtype=np.int64)) dataset = dataset_ops.Dataset.from_tensors(components) self.assertEquals( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset.output_types) self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) dataset = dataset.shuffle(10, 10) self.assertEquals( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset.output_types) self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) dataset = dataset.repeat(-1) self.assertEquals( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset.output_types) self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) dataset = dataset.filter(lambda x, y, z: True) self.assertEquals( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset.output_types) self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) dataset = dataset.take(5) self.assertEquals( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset.output_types) self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1]))) self.assertEquals( ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), dataset.output_types) self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes) dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset. from_tensors(((x[0], x[1]), (y[0], y[1])))) self.assertEquals( ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), dataset.output_types) self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes) dataset = dataset.batch(32) self.assertEquals( ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), dataset.output_types) self.assertEquals( (([None, 3], [None, 3]), ([None, 2], [None, 2])), nest.pack_sequence_as( dataset.output_shapes, [s.as_list() for s in nest.flatten(dataset.output_shapes)])) iterator = dataset.make_one_shot_iterator() (w, x), (y, z) = iterator.get_next() self.assertEquals(dtypes.int64, w.dtype) self.assertEquals(dtypes.int64, x.dtype) self.assertEquals(dtypes.float64, y.dtype) self.assertEquals(dtypes.float64, z.dtype) self.assertEquals([None, 3], w.shape.as_list()) self.assertEquals([None, 3], x.shape.as_list()) self.assertEquals([None, 2], y.shape.as_list()) self.assertEquals([None, 2], z.shape.as_list()) iterator = dataset.make_initializable_iterator() (w, x), (y, z) = iterator.get_next() self.assertEquals(dtypes.int64, w.dtype) self.assertEquals(dtypes.int64, x.dtype) self.assertEquals(dtypes.float64, y.dtype) self.assertEquals(dtypes.float64, z.dtype) self.assertEquals([None, 3], w.shape.as_list()) self.assertEquals([None, 3], x.shape.as_list()) self.assertEquals([None, 2], y.shape.as_list()) self.assertEquals([None, 2], z.shape.as_list()) # Define a separate set of components with matching leading # dimension for the from-slices constructor. components_for_slices = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5., 6.]), np.array([7., 8., 9.])), np.array([10, 11, 12], dtype=np.int64)) dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices) self.assertEquals( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset.output_types) self.assertEquals(([], ([], []), []), dataset.output_shapes)
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:
  * `dataset.output_types` must be the same as `output_types`, modulo
    nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.
    output_classes: (Optional.) A nested structure of class types. If omitted,
      the class types will be inherited from `dataset`.
    allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
      reported output types and shapes of the restructured dataset, e.g. to
      switch a sparse tensor represented as `tf.variant` to its user-visible
      type and shape.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  super(_RestructuredDataset, self).__init__()
  self._input_dataset = dataset

  if not allow_unsafe_cast:
    # Validate that the types are compatible.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    flat_original_types = nest.flatten(dataset.output_types)
    flat_new_types = nest.flatten(output_types)
    if flat_original_types != flat_new_types:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have "
          "output types %r" % (dataset.output_types, output_types))

  self._output_types = output_types

  if output_shapes is None:
    # Inherit shapes from the original `dataset`.
    self._output_shapes = nest.pack_sequence_as(
        output_types, nest.flatten(dataset.output_shapes))
  else:
    if not allow_unsafe_cast:
      # Validate that the shapes are compatible.
      nest.assert_same_structure(output_types, output_shapes)
      flat_original_shapes = nest.flatten(dataset.output_shapes)
      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
      for original_shape, new_shape in zip(flat_original_shapes,
                                           flat_new_shapes):
        if not original_shape.is_compatible_with(new_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (dataset.output_shapes,
                                                 output_shapes))
    self._output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)

  if output_classes is None:
    # Inherit class types from the original `dataset`.
    self._output_classes = nest.pack_sequence_as(
        output_types, nest.flatten(dataset.output_classes))
  else:
    self._output_classes = output_classes
def output_shapes(self):
  dim = self._batch_size if self._drop_remainder else None
  return nest.pack_sequence_as(self._output_shapes, [
      tensor_shape.vector(dim).concatenate(s)
      for s in nest.flatten(self._output_shapes)
  ])
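# Hedged illustration of the leading-dimension logic above, using the public
# TF 2.x batching API: with drop_remainder=True the batch dimension is
# statically known, otherwise it is left as None.
import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(ds.batch(4).element_spec.shape)                       # (None,)
print(ds.batch(4, drop_remainder=True).element_spec.shape)  # (4,)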
def __init__(self, input_dataset, initial_state, scan_func): """See `scan()` for details.""" super(_ScanDataset, self).__init__(input_dataset) self._input_dataset = input_dataset with ops.name_scope("initial_state"): # Convert any `SparseTensorValue`s to `SparseTensor`s and all other # values to tensors. self._initial_state = nest.pack_sequence_as( initial_state, [ sparse_tensor.SparseTensor.from_value(t) if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( t, name="component_%d" % i) for i, t in enumerate(nest.flatten(initial_state)) ]) # Compute initial values for the state classes, shapes and types based on # the initial state. The shapes may be refined by running `tf_scan_func` one # or more times below. self._state_structure = structure.Structure.from_value( self._initial_state) # Iteratively rerun the scan function until reaching a fixed point on # `self._state_shapes`. need_to_rerun = True while need_to_rerun: wrapped_func = dataset_ops.StructuredFunctionWrapper( scan_func, self._transformation_name(), input_structure=structure.NestedStructure( (self._state_structure, input_dataset._element_structure)), # pylint: disable=protected-access add_to_graph=False) if not (isinstance(wrapped_func.output_types, collections.Sequence) and len(wrapped_func.output_types) == 2): raise TypeError( "The scan function must return a pair comprising the " "new state and the output value.") new_state_classes, self._output_classes = wrapped_func.output_classes # Extract and validate class information from the returned values. new_state_classes, output_classes = wrapped_func.output_classes old_state_classes = self._state_structure._to_legacy_output_classes( ) # pylint: disable=protected-access for new_state_class, old_state_class in zip( nest.flatten(new_state_classes), nest.flatten(old_state_classes)): if not issubclass(new_state_class, old_state_class): raise TypeError( "The element classes for the new state must match the initial " "state. Expected %s; got %s." % (old_state_classes, new_state_classes)) # Extract and validate type information from the returned values. new_state_types, output_types = wrapped_func.output_types old_state_types = self._state_structure._to_legacy_output_types() # pylint: disable=protected-access for new_state_type, old_state_type in zip( nest.flatten(new_state_types), nest.flatten(old_state_types)): if new_state_type != old_state_type: raise TypeError( "The element types for the new state must match the initial " "state. Expected %s; got %s." % (old_state_types, new_state_types)) # Extract shape information from the returned values. new_state_shapes, output_shapes = wrapped_func.output_shapes old_state_shapes = self._state_structure._to_legacy_output_shapes() # pylint: disable=protected-access self._structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) flat_state_shapes = nest.flatten(old_state_shapes) flat_new_state_shapes = nest.flatten(new_state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) ] need_to_rerun = False for original_shape, weakened_shape in zip(flat_state_shapes, weakened_state_shapes): if original_shape.ndims is not None and ( weakened_shape.ndims is None or original_shape.as_list() != weakened_shape.as_list()): need_to_rerun = True break if need_to_rerun: # TODO(b/110122868): Support a "most specific compatible structure" # method for combining structures, to avoid using legacy structures # in this method. 
self._state_structure = structure.convert_legacy_structure( old_state_types, nest.pack_sequence_as(old_state_shapes, weakened_state_shapes), old_state_classes) self._scan_func = wrapped_func self._scan_func.function.add_to_graph(ops.get_default_graph())
def get_next(self, name=None): """Returns a nested structure of `tf.Tensor`s representing the next element. In graph mode, you should typically call this method *once* and use its result as the input to another computation. A typical loop will then call @{tf.Session.run} on the result of that computation. The loop will terminate when the `Iterator.get_next()` operation raises @{tf.errors.OutOfRangeError}. The following skeleton shows how to use this method when building a training loop: ```python dataset = ... # A `tf.data.Dataset` object. iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() # Build a TensorFlow graph that does something with each element. loss = model_function(next_element) optimizer = ... # A `tf.train.Optimizer` object. train_op = optimizer.minimize(loss) with tf.Session() as sess: try: while True: sess.run(train_op) except tf.errors.OutOfRangeError: pass ``` NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g. when you are distributing different elements to multiple devices in a single step. However, a common pitfall arises when users call `Iterator.get_next()` in each iteration of their training loop. `Iterator.get_next()` adds ops to the graph, and executing each op allocates resources (including threads); as a consequence, invoking it in every iteration of a training loop causes slowdown and eventual resource exhaustion. To guard against this outcome, we log a warning when the number of uses crosses a fixed threshold of suspiciousness. Args: name: (Optional.) A name for the created operation. Returns: A nested structure of `tf.Tensor` objects. """ self._get_next_call_count += 1 if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) return sparse.deserialize_sparse_tensors( nest.pack_sequence_as( self._output_types, gen_dataset_ops.iterator_get_next( self._iterator_resource, output_types=nest.flatten( sparse.as_dense_types(self._output_types, self._output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)), name=name)), self._output_types, self._output_shapes, self._output_classes)
def _as_variant_tensor(self):
  return gen_dataset_ops.sql_dataset(self._driver_name,
                                     self._data_source_name, self._query,
                                     nest.flatten(self.output_types),
                                     nest.flatten(self.output_shapes))
def make_initializer(self, dataset, name=None):
  """Returns a `tf.Operation` that initializes this iterator on `dataset`.

  Args:
    dataset: A `Dataset` whose `element_spec` is compatible with this
      iterator.
    name: (Optional.) A name for the created operation.

  Returns:
    A `tf.Operation` that can be run to initialize this iterator on the given
    `dataset`.

  Raises:
    TypeError: If `dataset` and this iterator do not have a compatible
      `element_spec`.
  """
  with ops.name_scope(name, "make_initializer") as name:
    # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
    # to that creating a circular dependency.
    # pylint: disable=protected-access
    dataset_output_types = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),
        dataset.element_spec)
    dataset_output_shapes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),
        dataset.element_spec)
    dataset_output_classes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),
        dataset.element_spec)
    # pylint: enable=protected-access

    nest.assert_same_structure(self.output_types, dataset_output_types)
    nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
    for iterator_class, dataset_class in zip(
        nest.flatten(self.output_classes),
        nest.flatten(dataset_output_classes)):
      if iterator_class is not dataset_class:
        raise TypeError(
            "Expected output classes %r but got dataset with output class %r."
            % (self.output_classes, dataset_output_classes))
    for iterator_dtype, dataset_dtype in zip(
        nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
      if iterator_dtype != dataset_dtype:
        raise TypeError(
            "Expected output types %r but got dataset with output types %r."
            % (self.output_types, dataset_output_types))
    for iterator_shape, dataset_shape in zip(
        nest.flatten(self.output_shapes),
        nest.flatten(dataset_output_shapes)):
      if not iterator_shape.is_compatible_with(dataset_shape):
        raise TypeError(
            "Expected output shapes compatible with %r but got dataset with "
            "output shapes %r." % (self.output_shapes,
                                   dataset_output_shapes))

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          dataset._variant_tensor, self._iterator_resource, name=name)
def from_structure(output_types, output_shapes=None, shared_name=None, output_classes=None): """Creates a new, uninitialized `Iterator` with the given structure. This iterator-constructing method can be used to create an iterator that is reusable with many different datasets. The returned iterator is not bound to a particular dataset, and it has no `initializer`. To initialize the iterator, run the operation returned by `Iterator.make_initializer(dataset)`. The following is an example ```python iterator = Iterator.from_structure(tf.int64, tf.TensorShape([])) dataset_range = Dataset.range(10) range_initializer = iterator.make_initializer(dataset_range) dataset_evens = dataset_range.filter(lambda x: x % 2 == 0) evens_initializer = iterator.make_initializer(dataset_evens) # Define a model based on the iterator; in this example, the model_fn # is expected to take scalar tf.int64 Tensors as input (see # the definition of 'iterator' above). prediction, loss = model_fn(iterator.get_next()) # Train for `num_epochs`, where for each epoch, we first iterate over # dataset_range, and then iterate over dataset_evens. for _ in range(num_epochs): # Initialize the iterator to `dataset_range` sess.run(range_initializer) while True: try: pred, loss_val = sess.run([prediction, loss]) except tf.errors.OutOfRangeError: break # Initialize the iterator to `dataset_evens` sess.run(evens_initializer) while True: try: pred, loss_val = sess.run([prediction, loss]) except tf.errors.OutOfRangeError: break ``` Args: output_types: A nested structure of `tf.DType` objects corresponding to each component of an element of this dataset. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element of this dataset. If omitted, each component will have an unconstrainted shape. shared_name: (Optional.) If non-empty, this iterator will be shared under the given name across multiple sessions that share the same devices (e.g. when using a remote server). output_classes: (Optional.) A nested structure of Python `type` objects corresponding to each component of an element of this iterator. If omitted, each component is assumed to be of type `tf.Tensor`. Returns: An `Iterator`. Raises: TypeError: If the structures of `output_shapes` and `output_types` are not the same. 
""" output_types = nest.map_structure(dtypes.as_dtype, output_types) if output_shapes is None: output_shapes = nest.map_structure( lambda _: tensor_shape.TensorShape(None), output_types) else: output_shapes = nest.map_structure_up_to(output_types, tensor_shape.as_shape, output_shapes) if output_classes is None: output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) if shared_name is None: shared_name = "" if compat.forward_compatible(2018, 8, 3): if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access with ops.device("/cpu:0"): iterator_resource = gen_dataset_ops.iterator_v2( container="", shared_name=shared_name, output_types=nest.flatten( sparse.as_dense_types(output_types, output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(output_shapes, output_classes))) else: iterator_resource = gen_dataset_ops.iterator_v2( container="", shared_name=shared_name, output_types=nest.flatten( sparse.as_dense_types(output_types, output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(output_shapes, output_classes))) else: iterator_resource = gen_dataset_ops.iterator( container="", shared_name=shared_name, output_types=nest.flatten( sparse.as_dense_types(output_types, output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes)
def testMakeCSVDataset_withShuffle(self): record_defaults = [ constant_op.constant([], dtypes.int32), constant_op.constant([], dtypes.int64), constant_op.constant([], dtypes.float32), constant_op.constant([], dtypes.float64), constant_op.constant([], dtypes.string) ] def str_series(st): return ",".join(str(i) for i in range(st, st + 5)) column_names = ["col%d" % i for i in range(5)] inputs = [ [",".join(x for x in column_names) ] + [str_series(5 * i) for i in range(15)], [",".join(x for x in column_names)] + [str_series(5 * i) for i in range(15, 20)], ] filenames = self._setup_files(inputs) total_records = 20 for batch_size in [1, 2]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Test that shuffling with the same seed produces the same result dataset1 = self._make_csv_dataset( filenames, column_defaults=record_defaults, column_names=column_names, batch_size=batch_size, header=True, shuffle=True, shuffle_seed=5, num_epochs=2, ) dataset2 = self._make_csv_dataset( filenames, column_defaults=record_defaults, column_names=column_names, batch_size=batch_size, header=True, shuffle=True, shuffle_seed=5, num_epochs=2, ) outputs1 = dataset1.make_one_shot_iterator().get_next() outputs2 = dataset2.make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): batch1 = nest.flatten(sess.run(outputs1)) batch2 = nest.flatten(sess.run(outputs2)) for i in range(len(batch1)): self.assertAllEqual(batch1[i], batch2[i]) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Test that shuffling with a different seed produces different results dataset1 = self._make_csv_dataset( filenames, column_defaults=record_defaults, column_names=column_names, batch_size=batch_size, header=True, shuffle=True, shuffle_seed=5, num_epochs=2, ) dataset2 = self._make_csv_dataset( filenames, column_defaults=record_defaults, column_names=column_names, batch_size=batch_size, header=True, shuffle=True, shuffle_seed=6, num_epochs=2, ) outputs1 = dataset1.make_one_shot_iterator().get_next() outputs2 = dataset2.make_one_shot_iterator().get_next() all_equal = False for _ in range(total_records // batch_size): batch1 = nest.flatten(sess.run(outputs1)) batch2 = nest.flatten(sess.run(outputs2)) for i in range(len(batch1)): all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal)
def from_string_handle(string_handle, output_types, output_shapes=None, output_classes=None): """Creates a new, uninitialized `Iterator` based on the given handle. This method allows you to define a "feedable" iterator where you can choose between concrete iterators by feeding a value in a @{tf.Session.run} call. In that case, `string_handle` would a @{tf.placeholder}, and you would feed it with the value of @{tf.data.Iterator.string_handle} in each step. For example, if you had two iterators that marked the current position in a training dataset and a test dataset, you could choose which to use in each step as follows: ```python train_iterator = tf.data.Dataset(...).make_one_shot_iterator() train_iterator_handle = sess.run(train_iterator.string_handle()) test_iterator = tf.data.Dataset(...).make_one_shot_iterator() test_iterator_handle = sess.run(test_iterator.string_handle()) handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle( handle, train_iterator.output_types) next_element = iterator.get_next() loss = f(next_element) train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle}) test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle}) ``` Args: string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to a handle produced by the `Iterator.string_handle()` method. output_types: A nested structure of `tf.DType` objects corresponding to each component of an element of this dataset. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element of this dataset. If omitted, each component will have an unconstrainted shape. output_classes: (Optional.) A nested structure of Python `type` objects corresponding to each component of an element of this iterator. If omitted, each component is assumed to be of type `tf.Tensor`. Returns: An `Iterator`. """ output_types = nest.map_structure(dtypes.as_dtype, output_types) if output_shapes is None: output_shapes = nest.map_structure( lambda _: tensor_shape.TensorShape(None), output_types) else: output_shapes = nest.map_structure_up_to(output_types, tensor_shape.as_shape, output_shapes) if output_classes is None: output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) if compat.forward_compatible(2018, 8, 3): if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access with ops.device("/cpu:0"): iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( string_handle, output_types=nest.flatten( sparse.as_dense_types(output_types, output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(output_shapes, output_classes))) else: iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( string_handle, output_types=nest.flatten( sparse.as_dense_types(output_types, output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(output_shapes, output_classes))) else: iterator_resource = gen_dataset_ops.iterator_from_string_handle( string_handle, output_types=nest.flatten( sparse.as_dense_types(output_types, output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes)
def __init__(self, dataset, devices, max_buffer_size=1, prefetch_buffer_size=1, source_device="/cpu:0"): """Constructs a MultiDeviceIterator. Args: dataset: The input dataset to be iterated over. devices: The list of devices to fetch data to. max_buffer_size: Maximum size of the host-side per-device buffer to keep. prefetch_buffer_size: If > 0, a prefetch buffer of this size is set up on each device. source_device: The host device to place the `dataset` on. """ self._dataset = dataset self._devices = devices self._source_device = source_device self._source_device_tensor = ops.convert_to_tensor(source_device) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._dataset.output_shapes, self._dataset.output_classes)) self._flat_output_types = nest.flatten( sparse.as_dense_types(self._dataset.output_types, self._dataset.output_classes)) # Create the MultiDeviceIterator. with ops.device(self._source_device): self._multi_device_iterator_resource = ( gen_dataset_ops.multi_device_iterator( devices=self._devices, shared_name="", container="", output_types=self._flat_output_types, output_shapes=self._flat_output_shapes)) # The incarnation ID is used to ensure consistency between the per-device # iterators and the multi-device iterator. self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( self._dataset._as_variant_tensor(), # pylint: disable=protected-access self._multi_device_iterator_resource, max_buffer_size=max_buffer_size) # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to # initialize the device side of the pipeline. This would allow the # MultiDeviceIterator to choose, for example, to move some transformations # into the device side from its input. It might be useful in rewriting. # Create the per device iterators. self._device_iterators = [] for i, device in enumerate(self._devices): ds = _PerDeviceGenerator( i, self._multi_device_iterator_resource, self._incarnation_id, self._source_device_tensor, device, self._dataset.output_shapes, self._dataset.output_types, self._dataset.output_classes) if prefetch_buffer_size > 0: ds = ds.prefetch(prefetch_buffer_size) with ops.device(device): self._device_iterators.append(ds.make_initializable_iterator()) device_iterator_initializers = [ iterator.initializer for iterator in self._device_iterators ] self._initializer = control_flow_ops.group(*device_iterator_initializers)
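A hedged usage sketch for this constructor. The wrapper name `tf.data.experimental.MultiDeviceIterator`, its `initializer` property, and its no-argument `get_next()` are assumptions about the public surface (they are not shown in the snippet above), and the example presumes a GPU device exists:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(10)

# Assumed public wrapper around the constructor above.
mdi = tf.data.experimental.MultiDeviceIterator(
    dataset, devices=["/cpu:0", "/gpu:0"], prefetch_buffer_size=1)

per_device_elements = mdi.get_next()  # expected: one element per device, in order
with tf.Session() as sess:
  sess.run(mdi.initializer)
  print(sess.run(per_device_elements))
```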
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, source_device, target_device, output_shapes, output_types, output_classes): self._target_device = target_device self._output_types = output_types self._output_shapes = output_shapes self._output_classes = output_classes self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) self._flat_output_types = nest.flatten( sparse.as_dense_types(self._output_types, self._output_classes)) multi_device_iterator_string_handle = ( gen_dataset_ops.multi_device_iterator_to_string_handle( multi_device_iterator_resource)) @function.Defun() def _init_func(): return multi_device_iterator_string_handle @function.Defun() def _remote_init_func(): return functional_ops.remote_call( target=source_device, args=_init_func.captured_inputs, Tout=[dtypes.string], f=_init_func) self._init_func = _remote_init_func self._init_captured_args = _remote_init_func.captured_inputs @function.Defun(dtypes.string) def _next_func(string_handle): multi_device_iterator = ( gen_dataset_ops.multi_device_iterator_from_string_handle( string_handle=string_handle, output_types=self._flat_output_types, output_shapes=self._flat_output_shapes)) return gen_dataset_ops.multi_device_iterator_get_next_from_shard( multi_device_iterator=multi_device_iterator, shard_num=shard_num, incarnation_id=incarnation_id, output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) @function.Defun(dtypes.string) def _remote_next_func(string_handle): return functional_ops.remote_call( target=source_device, args=[string_handle] + _next_func.captured_inputs, Tout=self._flat_output_types, f=_next_func) self._next_func = _remote_next_func self._next_captured_args = _remote_next_func.captured_inputs @function.Defun(dtypes.string) def _finalize_func(unused_string_handle): return array_ops.constant(0, dtypes.int64) @function.Defun(dtypes.string) def _remote_finalize_func(string_handle): return functional_ops.remote_call( target=source_device, args=[string_handle] + _finalize_func.captured_inputs, Tout=[dtypes.int64], f=_finalize_func) self._finalize_func = _remote_finalize_func self._finalize_captured_args = _remote_finalize_func.captured_inputs
def _inputs(self): # _inputs() is expected to return a flat list, so flatten the nested inputs. return nest.flatten(self._input_datasets)
def output_shapes(self): return nest.pack_sequence_as(self._output_shapes, [ tensor_shape.vector(tensor_util.constant_value( self._batch_size)).concatenate(s) for s in nest.flatten(self._output_shapes) ])
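The property above prepends the statically known batch size to every component shape of the nested structure. The same transformation, sketched with the public `tf.nest` and `tf.TensorShape` in place of the internal `nest` and `tensor_shape` modules:

```python
import tensorflow as tf

batch_size = 32
output_shapes = (tf.TensorShape([10]),
                 (tf.TensorShape([2, 3]), tf.TensorShape([])))

# Prepend the batch dimension to every flattened shape, then restore the
# original nesting.
batched_shapes = tf.nest.pack_sequence_as(output_shapes, [
    tf.TensorShape([batch_size]).concatenate(s)
    for s in tf.nest.flatten(output_shapes)
])

print(batched_shapes)
# (TensorShape([32, 10]), (TensorShape([32, 2, 3]), TensorShape([32])))
```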
def converted_call(f, owner, options, *args, **kwargs): """Compiles a function call inline. For internal use only.""" if options.verbose: logging.info('Converted call: {}; owner: {}'.format(f, owner)) if owner is not None: if not isinstance(f, str): raise ValueError( 'When owner is specified, the function name must be specified as' ' a string: {}'.format(f)) # Special case when the owner is a 'super' object. In that case lookups of # dynamic attributes won't work. See # inspect_utils.SuperWrapperForDynamicAttrs. if isinstance(owner, super): owner = inspect_utils.SuperWrapperForDynamicAttrs(owner) f = getattr(owner, f) # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. if not options.force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) if inspect_utils.isbuiltin(f): return py_builtins.overload_of(f)(*args, **kwargs) # internal_convert_user_code is for example turned off when issuing a dynamic # call conversion from generated code while in nonrecursive mode. In that # case we evidently don't want to recurse, but we still have to convert # things like builtins. if not options.internal_convert_user_code: return f(*args, **kwargs) if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): # Regular functions target_entity = f arg_map_target = f f_class = inspect_utils.getmethodclass(f) if f_class is not None: # If this is a method call, it may or may not include self. # # Example when self is included: # converted_call(to_graph(foo.bar), foo) # # Example when self is not included: # super(...).foo(args) # if owner is not None and (not args or args[0] is not owner): effective_args = (owner, ) + args else: effective_args = args partial_types = (f_class, ) else: effective_args = args partial_types = () elif tf_inspect.isclass(f): # Constructors target_entity = f arg_map_target = f.__init__ effective_args = args partial_types = () elif hasattr(f, '__call__') and hasattr(f, '__class__'): # Callable objects target_entity = f.__call__ arg_map_target = f.__call__ effective_args = (f, ) + args partial_types = (f.__class__, ) else: raise NotImplementedError('unknown callable type "%s"' % type(f)) arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) arg_types = {} for name, arg in arg_values.items(): arg_class = arg.__class__ arg_types[name] = (arg_class.__name__, arg_class) # When called from within a decorator, this is the only indication that # the function is a method - it appears that the decorator is applied # before the method is bound. if not partial_types: if 'self' in arg_values: if tf_inspect.isclass(arg_values['self'].__class__): partial_types = (arg_values['self'].__class__, ) elif 'cls' in arg_values: if tf_inspect.isclass(arg_values['cls']): partial_types = (arg_values['cls'], ) converted_f = to_graph(target_entity, recursive=options.recursive, verbose=options.verbose, arg_values=arg_values, arg_types=arg_types, partial_types=partial_types, strip_decorators=options.strip_decorators, optional_features=options.optional_features) result = converted_f(*effective_args, **kwargs) # When converting a function, we write a tmp file and import it as a module. # This leaks the module's closure. Once we've executed the converted_f module # and there is no more code left to be executed, we can clean up the module. # TODO(mdan): Look into workarounds that don't suffer from refcount leaks. # Possibly attach the closure as a regular closure cell, instead of relying on # module globals.
# If there are callables in the result, they will fail to find their closure # when called, so only delete the module if none of the returned values are callable. flat_results = nest.flatten(result) if all(map(_is_not_callable, flat_results)): del sys.modules[converted_f.__module__] return result
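The cleanup guard above only deletes the generated module when nothing callable escapes in the result. A minimal sketch of that check; the `_is_not_callable` helper here is a hypothetical stand-in for the one referenced above:

```python
import tensorflow as tf


def _is_not_callable(obj):  # hypothetical stand-in for the helper used above
  return not callable(obj)


result = (1, [2.0, "three"], lambda: 4)
flat_results = tf.nest.flatten(result)

# False: the lambda leaf is callable, so the generated module must be kept.
print(all(map(_is_not_callable, flat_results)))
```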
def __init__(self, dataset): """Creates a new iterator over the given dataset. For example: ```python dataset = tf.data.Dataset.range(4) for x in Iterator(dataset): print(x) ``` Tensors produced will be placed on the device on which this iterator object was created. Args: dataset: A `tf.data.Dataset` object. Raises: RuntimeError: When invoked without eager execution enabled. """ if not context.in_eager_mode(): raise RuntimeError( "{} objects can only be used when eager execution is enabled, use " "tf.data.Dataset.make_initializable_iterator or " "tf.data.Dataset.make_one_shot_iterator for graph construction". format(type(self))) with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_classes = dataset.output_classes self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes self._flat_output_types = nest.flatten( sparse.as_dense_types(self._output_types, self._output_classes)) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) self._resource = gen_dataset_ops.iterator( shared_name="", container=_generate_shared_name("eageriterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) # Delete the resource when this object is deleted. self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="/device:CPU:0") self._device = context.context().device_name self._buffer_resource_handle = None if not context.context().device_spec.device_type: is_remote_device = False else: is_remote_device = context.context().device_spec.device_type != "CPU" if is_remote_device: with ops.device("/device:CPU:0"): iter_string_handle = gen_dataset_ops.iterator_to_string_handle( self._resource) @function.Defun(dtypes.string) def remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( h, self._output_types, self._output_shapes) return remote_iterator.get_next() remote_fn.add_to_graph(None) target = constant_op.constant("/device:CPU:0") with ops.device(self._device): self._buffer_resource_handle = prefetching_ops.function_buffering_resource( string_arg=iter_string_handle, f=remote_fn, target_device=target, buffer_size=10, thread_pool_size=1, container="", shared_name=_generate_shared_name("function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._buffer_resource_handle, handle_device=self._device)
def _flat_shapes(dataset): return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
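Roughly the same result can be obtained from the public API; this sketch assumes `tf.nest.flatten` and the legacy `output_shapes` property behave like the internal `nest` module and `get_legacy_output_shapes` used by the helper:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(
    (tf.zeros([3]), {"x": tf.zeros([2, 2])}))

# dataset.output_shapes is the legacy per-component shape structure.
flat_shapes = tf.nest.flatten(dataset.output_shapes)
print(flat_shapes)  # [TensorShape([3]), TensorShape([2, 2])]
```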
def testLegacyStructureAPI(self): components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]), np.array([6., 7.])), np.array([8, 9, 10], dtype=np.int64)) dataset = dataset_ops.Dataset.from_tensors(components) self.assertEqual( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(([3], ([2], [2]), [3]), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.shuffle(10, 10) self.assertEqual( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(([3], ([2], [2]), [3]), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.repeat(-1) self.assertEqual( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(([3], ([2], [2]), [3]), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.filter(lambda x, y, z: True) self.assertEqual( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(([3], ([2], [2]), [3]), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.take(5) self.assertEqual( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(([3], ([2], [2]), [3]), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1]))) self.assertEqual( ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual((([3], [3]), ([2], [2])), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset. from_tensors(((x[0], x[1]), (y[0], y[1])))) self.assertEqual( ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual((([3], [3]), ([2], [2])), dataset_ops.get_legacy_output_shapes(dataset)) dataset = dataset.batch(32) self.assertEqual( ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)), dataset_ops.get_legacy_output_types(dataset)) dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) self.assertEqual( (([None, 3], [None, 3]), ([None, 2], [None, 2])), nest.pack_sequence_as( dataset_output_shapes, [s.as_list() for s in nest.flatten(dataset_output_shapes)])) # Define a separate set of components with matching leading # dimension for the from-slices constructor. components_for_slices = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5., 6.]), np.array([7., 8., 9.])), np.array([10, 11, 12], dtype=np.int64)) dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices) self.assertEqual( (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64), dataset_ops.get_legacy_output_types(dataset)) self.assertEqual(([], ([], []), []), dataset_ops.get_legacy_output_shapes(dataset))
def __init__(self, input_dataset, target_device, source_device="/cpu:0"): """Constructs a _CopyToDeviceDataset. Args: input_dataset: `Dataset` to be copied target_device: The name of the device to which elements would be copied. source_device: Device where input_dataset would be placed. """ super(_CopyToDeviceDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._target_device = target_device spec = framework_device.DeviceSpec().from_string(self._target_device) self._is_gpu_target = (spec.device_type == "GPU") self._source_device_string = source_device self._source_device = ops.convert_to_tensor(source_device) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._input_dataset.output_shapes, self._input_dataset.output_classes)) self._flat_output_types = nest.flatten( sparse.as_dense_types(self._input_dataset.output_types, self._input_dataset.output_classes)) @function.Defun() def _init_func(): """Creates an iterator for the input dataset. Returns: A `string` tensor that encapsulates the iterator created. """ # pylint: disable=protected-access ds_variant = self._input_dataset._as_variant_tensor() resource = core_gen_dataset_ops.anonymous_iterator( output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) with ops.control_dependencies( [core_gen_dataset_ops.make_iterator(ds_variant, resource)]): return core_gen_dataset_ops.iterator_to_string_handle(resource) @function.Defun() def _remote_init_func(): return functional_ops.remote_call( target=self._source_device, args=_init_func.captured_inputs, Tout=[dtypes.string], f=_init_func) self._init_func = _remote_init_func self._init_captured_args = _remote_init_func.captured_inputs @function.Defun(dtypes.string) def _next_func(string_handle): """Calls get_next for created iterator. Args: string_handle: An iterator string handle created by _init_func Returns: The elements generated from `input_dataset` """ with ops.device(self._source_device_string): iterator = iterator_ops.Iterator.from_string_handle( string_handle, self.output_types, self.output_shapes, self.output_classes) ret = iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) @function.Defun(dtypes.string) def _remote_next_func(string_handle): return functional_ops.remote_call( target=self._source_device, args=[string_handle] + _next_func.captured_inputs, Tout=self._flat_output_types, f=_next_func) self._next_func = _remote_next_func self._next_captured_args = _remote_next_func.captured_inputs @function.Defun(dtypes.string) def _finalize_func(string_handle): """Destroys the iterator resource created. Args: string_handle: An iterator string handle created by _init_func Returns: Tensor constant 0 """ iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2( string_handle, output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) with ops.control_dependencies([ resource_variable_ops.destroy_resource_op( iterator_resource, ignore_lookup_error=True)]): return array_ops.constant(0, dtypes.int64) @function.Defun(dtypes.string) def _remote_finalize_func(string_handle): return functional_ops.remote_call( target=self._source_device, args=[string_handle] + _finalize_func.captured_inputs, Tout=[dtypes.int64], f=_finalize_func) self._finalize_func = _remote_finalize_func self._finalize_captured_args = _remote_finalize_func.captured_inputs g = ops.get_default_graph() _remote_init_func.add_to_graph(g) _remote_next_func.add_to_graph(g) _remote_finalize_func.add_to_graph(g)
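A hedged end-to-end sketch of how this dataset is normally reached from user code. It assumes the `copy_to_device` transformation (exposed as `tf.contrib.data.copy_to_device` at the time, later `tf.data.experimental.copy_to_device`) wraps `_CopyToDeviceDataset`, and that a GPU is present:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# Copy elements from the host to the GPU, then prefetch on the GPU so that
# get_next() only has to dequeue from device memory.
dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
with tf.device("/gpu:0"):
  dataset = dataset.prefetch(1)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_element))
```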
def converted_call(f, owner, options, *args, **kwargs): """Compiles a function call inline. For internal use only.""" logging.vlog(logging.DEBUG, 'Converted call: %s; owner: %s', f, owner) if owner is not None: if not isinstance(f, str): raise ValueError( 'When owner is specified, the function name must be specified as' ' a string: {}'.format(f)) # Special case when the owner is a 'super' object. In that case lookups of # dynamic attributes won't work. See # inspect_utils.SuperWrapperForDynamicAttrs. if isinstance(owner, super): owner = inspect_utils.SuperWrapperForDynamicAttrs(owner) f = getattr(owner, f) if inspect_utils.isbuiltin(f): return py_builtins.overload_of(f)(*args, **kwargs) # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. if not options.force_conversion and conversion.is_whitelisted_for_graph(f): # TODO(mdan): This may be inconsistent in certain situations. # If the function had already been annotated with @tf.function, it # may be bound to the incorrect object. It's unclear if those situations # are possible, but if they happen, we need to check if f is bound # to a shim like WeakrefSelf and unpack it. # Args typically include `self`, as required by the conversion process. # When conversion is skipped, `self` is not necessary, because the # original bound method is being executed. This code removes it. if tf_inspect.ismethod(f) and args: f_self = inspect_utils.getmethodself(f) if args[0] is f_self: args = args[1:] return f(*args, **kwargs) # internal_convert_user_code is for example turned off when issuing a dynamic # call conversion from generated code while in nonrecursive mode. In that # case we evidently don't want to recurse, but we still have to convert # things like builtins. if not options.internal_convert_user_code: return f(*args, **kwargs) # Unwrap functools.partial objects # TODO(mdan): Consider sharing unwrapping logic with tf_inspect. while isinstance(f, functools.partial): args = f.args + args new_kwargs = {} if f.keywords is not None: new_kwargs.update(f.keywords) new_kwargs.update(kwargs) kwargs = new_kwargs f = f.func if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): # Regular functions target_entity = f arg_map_target = f f_self = inspect_utils.getmethodself(f) # TODO(b/119246461): This may be more elegantly handled using __get__? if f_self is not None: # If this is a method call, it may or may not include self. # # Example when self is included: # converted_call(to_graph(foo.bar), foo) # # Example when self is not included: # super(...).foo(args) # if owner is not None and (not args or args[0] is not owner): effective_args = (owner, ) + args else: # When the owner is not specified, use the result of # inspect_utils.getmethodclass. # TODO(b/119246461): Make sure an owner is always specified. 
if not args or args[0] is not f_self: effective_args = (f_self, ) + args else: effective_args = (f_self, ) + args[1:] partial_types = (f_self, ) else: effective_args = args partial_types = () elif tf_inspect.isclass(f): # Constructors target_entity = f arg_map_target = f.__init__ effective_args = args partial_types = () elif hasattr(f, '__call__') and hasattr(f, '__class__'): # Callable objects target_entity = f.__call__ arg_map_target = f.__call__ effective_args = (f, ) + args partial_types = (f.__class__, ) else: raise NotImplementedError('unknown callable type "%s"' % type(f)) arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) arg_types = {} for name, arg in arg_values.items(): arg_class = arg.__class__ arg_types[name] = (arg_class.__name__, arg_class) # When called from within a decorator, this is the only indication that # the function is a method - it appears that the decorator is applied # before the method is bound. if not partial_types: if 'self' in arg_values: if tf_inspect.isclass(arg_values['self'].__class__): partial_types = (arg_values['self'].__class__, ) elif 'cls' in arg_values: if tf_inspect.isclass(arg_values['cls']): partial_types = (arg_values['cls'], ) converted_f = to_graph( target_entity, recursive=options.recursive, arg_values=arg_values, arg_types=arg_types, experimental_optional_features=options.optional_features, experimental_strip_decorators=options.strip_decorators, experimental_verbose=options.verbose, experimental_partial_types=partial_types) result = converted_f(*effective_args, **kwargs) # The converted function's closure is simply inserted into the function's # module __dict__. Since modules are permanently cached, that results in # leaking the entire closure. # Normally, it's not safe to delete the module because that may release said # closure as well. However, in the case of converted_call we are certain the # function will not be executed again, so the closure should no longer be # needed so long as the function doesn't return any executable code. # TODO(mdan): Attach the closure properly, using cells. if all(map(_is_not_callable, nest.flatten(result))): del sys.modules[converted_f.__module__] return result
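The `functools.partial` unwrapping loop near the top of this function can be exercised on its own; a minimal sketch in plain Python, independent of the rest of the conversion machinery:

```python
import functools


def add(a, b, c=0):
  return a + b + c


f = functools.partial(functools.partial(add, 1), 2, c=10)

args, kwargs = (), {}
# Mirror of the unwrapping loop: fold the positional and keyword arguments of
# nested partials into the eventual call.
while isinstance(f, functools.partial):
  args = f.args + args
  new_kwargs = {}
  if f.keywords is not None:
    new_kwargs.update(f.keywords)
  new_kwargs.update(kwargs)
  kwargs = new_kwargs
  f = f.func

print(f(*args, **kwargs))  # 13, the same as calling the nested partial directly
```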
def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn): input_datasets = input_datasets_fn() self.assertEqual( nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())
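A concrete instance of the property this test exercises, phrased against the public API: a variadic transformation such as `Dataset.zip` should report the flattened input datasets as its `_inputs()`. Whether the returned objects compare equal to `a` and `b` may depend on V1/V2 wrapping in a given TensorFlow version, so treat this as a sketch:

```python
import tensorflow as tf

a = tf.data.Dataset.range(3)
b = tf.data.Dataset.range(3, 6)
zipped = tf.data.Dataset.zip((a, b))

# Mirrors the test's assertion: the zipped dataset's inputs are the
# flattened input datasets, in order.
print(zipped._inputs() == [a, b])  # expected True
```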