class StructureTest(test.TestCase, parameterized.TestCase):
  """Tests for `structure.Structure` objects built from values.

  Covers the structure class and flat types/shapes that
  `Structure.from_value()` reports, `is_compatible_with()`, round-tripping
  through `_to_tensor_list()` / `_from_tensor_list()`, the errors raised when
  a structure is applied to an incompatible value or flat list, and
  `Structure._from_legacy_structure()`.
  """

  # The tests below deliberately exercise private `Structure` members
  # (`_flat_types`, `_flat_shapes`, `_to_tensor_list`, `_from_tensor_list`,
  # `_from_legacy_structure`).
  # pylint: disable=protected-access

  # Each parameter tuple is (value, expected structure class, expected
  # `_flat_types`, expected `_flat_shapes`). Per the expectations below, every
  # SparseTensor contributes a single `variant` component of shape [3].
  @parameterized.parameters(
      (constant_op.constant(37.0), structure.TensorStructure,
       [dtypes.float32], [[]]),
      (sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       structure.SparseTensorStructure, [dtypes.variant], [[3]]),
      ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
      ({
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, structure.NestedStructure, [dtypes.float32, dtypes.int32],
       [[], [3]]),
      ({
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, structure.NestedStructure,
       [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
  def testFlatStructure(self, value, expected_structure, expected_types,
                        expected_shapes):
    # Checks the class of the structure derived from `value` and the flat
    # dtype/shape lists it exposes.
    s = structure.Structure.from_value(value)
    self.assertIsInstance(s, expected_structure)
    self.assertEqual(expected_types, s._flat_types)
    self.assertEqual(expected_shapes, s._flat_shapes)

  # Each parameter tuple is (original value, list of values whose structures
  # must be compatible with it, list of values whose structures must not be).
  @parameterized.parameters(
      (constant_op.constant(37.0), [
          constant_op.constant(38.0),
          array_ops.placeholder(dtypes.float32),
          variables.Variable(100.0), 42.0,
          np.array(42.0, dtype=np.float32)
      ], [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
      (sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), [
              sparse_tensor.SparseTensor(
                  indices=[[1, 1], [3, 4]], values=[10, -1],
                  dense_shape=[4, 5]),
              sparse_tensor.SparseTensorValue(
                  indices=[[1, 1], [3, 4]], values=[10, -1],
                  dense_shape=[4, 5]),
              array_ops.sparse_placeholder(dtype=dtypes.int32),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None])
          ], [
              constant_op.constant(37, shape=[4, 5]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None, None]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
          ]),
      ({
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6])
      }], [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6, 7])
      }, {
          "a": constant_op.constant(15),
          "b": constant_op.constant([4, 5, 6])
      }, {
          "a": constant_op.constant(15),
          "b": sparse_tensor.SparseTensor(
              indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
      }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  def testIsCompatibleWithStructure(self, original_value, compatible_values,
                                    incompatible_values):
    # The structure of `original_value` is the receiver of every
    # `is_compatible_with()` call; the candidates are only ever the argument.
    s = structure.Structure.from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          s.is_compatible_with(
              structure.Structure.from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          s.is_compatible_with(
              structure.Structure.from_value(incompatible_value)))

  # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
  # will be executed before the (eager- or graph-mode) test environment has been
  # set up.
  # pylint: disable=g-long-lambda
  @parameterized.parameters(
      (lambda: constant_op.constant(37.0),),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      },),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      },),
  )
  def testRoundTripConversion(self, value_fn):
    # Flattens the value to a tensor list, rebuilds it with the same
    # structure, and checks that the evaluated result matches the evaluated
    # original component by component.
    value = value_fn()
    s = structure.Structure.from_value(value)
    before = self.evaluate(value)
    after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))

    flat_before = nest.flatten(before)
    flat_after = nest.flatten(after)
    for b, a in zip(flat_before, flat_after):
      if isinstance(b, sparse_tensor.SparseTensorValue):
        # Sparse values are compared field by field.
        self.assertAllEqual(b.indices, a.indices)
        self.assertAllEqual(b.values, a.values)
        self.assertAllEqual(b.dense_shape, a.dense_shape)
      else:
        self.assertAllEqual(b, a)
  # pylint: enable=g-long-lambda

  def testIncompatibleStructure(self):
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.Structure.from_value(value_tensor)
    flat_tensor = s_tensor._to_tensor_list(value_tensor)

    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
    flat_sparse_tensor = s_sparse_tensor._to_tensor_list(
        value_sparse_tensor)

    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.Structure.from_value(value_nest)
    flat_nest = s_nest._to_tensor_list(value_nest)

    # Part 1: flattening a value with each of the other two structures.
    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_sparse_tensor)
    with self.assertRaisesRegexp(
        ValueError, r"Value \{.*\} is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
      s_sparse_tensor._to_tensor_list(value_tensor)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
      s_sparse_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
      s_nest._to_tensor_list(value_tensor)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
      s_nest._to_tensor_list(value_sparse_tensor)

    # Part 2: rebuilding from a flat list produced by each of the other two
    # structures.
    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
      s_tensor._from_tensor_list(flat_sparse_tensor)

    with self.assertRaisesRegexp(
        ValueError, "TensorStructure corresponds to a single tf.Tensor."):
      s_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_sparse_tensor._from_tensor_list(flat_tensor)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_sparse_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 1."):
      s_nest._from_tensor_list(flat_tensor)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 1."):
      s_nest._from_tensor_list(flat_sparse_tensor)

  def testIncompatibleNestedStructure(self):
    # Define three mutually incompatible nested values/structures, and assert
    # that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.

    value_0 = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.Structure.from_value(value_0)
    flat_s_0 = s_0._to_tensor_list(value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
    # classes.
    value_1 = {
        "a": constant_op.constant(37.0),
        "b": sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.Structure.from_value(value_1)
    flat_s_1 = s_1._to_tensor_list(value_1)

    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
    value_2 = {
        "a": constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.Structure.from_value(value_2)
    flat_s_2 = s_2._to_tensor_list(value_2)

    # Part 1: flattening a value with a different nested structure.
    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure"):
      s_0._to_tensor_list(value_1)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
      s_0._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*SparseTensorStructure"):
      s_1._to_tensor_list(value_0)

    # NOTE(review): this repeats the `s_0._to_tensor_list(value_2)` check
    # above verbatim, while `s_1._to_tensor_list(value_2)` is never exercised
    # in this test. Possibly `s_1` was intended here -- TODO confirm.
    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
      s_0._to_tensor_list(value_2)

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
    # needs to account for "a" coming before or after "b". It might be worth
    # adding a deterministic repr for these error messages (among other
    # improvements).
    with self.assertRaisesRegexp(
        ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
        ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
    ):
      s_2._to_tensor_list(value_0)

    with self.assertRaisesRegexp(
        ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
        "not compatible with the nested structure .*"
        "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
    ):
      s_2._to_tensor_list(value_1)

    # Part 2: rebuilding from a flat list produced by a different structure.
    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
      s_0._from_tensor_list(flat_s_1)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_0._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_1._from_tensor_list(flat_s_0)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_1._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "Expected 3 flat values in NestedStructure but got 2."):
      s_2._from_tensor_list(flat_s_0)

    with self.assertRaisesRegexp(
        ValueError, "Expected 3 flat values in NestedStructure but got 2."):
      s_2._from_tensor_list(flat_s_1)

  # Each named case is (test name suffix, output_types, output_shapes,
  # output_classes, expected structure).
  @parameterized.named_parameters(
      ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", dtypes.int32, tensor_shape.matrix(
          2, 2), sparse_tensor.SparseTensor,
       structure.SparseTensorStructure(dtypes.int32, [2, 2])),
      ("Nest", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.scalar(),
          "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      },
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                 structure.TensorStructure(dtypes.string, []))
       })),
  )
  def testFromLegacyStructure(self, output_types, output_shapes,
                              output_classes, expected_structure):
    # The structure built from the legacy (types, shapes, classes) triple must
    # be compatible with the expected structure in both directions.
    actual_structure = structure.Structure._from_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertTrue(
        expected_structure.is_compatible_with(actual_structure))
    self.assertTrue(
        actual_structure.is_compatible_with(expected_structure))
def _element_structure(self):
  """Returns a NestedStructure of `self._num_outputs` string components.

  The nested value is a tuple in which one scalar-shaped (`[]`) string
  `TensorStructure` object is repeated `self._num_outputs` times.
  """
  string_component = structure.TensorStructure(dtypes.string, [])
  # Tuple repetition reuses the same TensorStructure object in every slot.
  components = (string_component,) * self._num_outputs
  return structure.NestedStructure(components)
def _element_structure(self):
  """Returns a NestedStructure holding a pair of string components.

  The nested value is a 2-tuple of separately constructed scalar-shaped
  (`[]`) string `TensorStructure` objects.
  """
  first_component = structure.TensorStructure(dtypes.string, [])
  second_component = structure.TensorStructure(dtypes.string, [])
  return structure.NestedStructure((first_component, second_component))
class IteratorTest(test.TestCase, parameterized.TestCase): @test_util.run_deprecated_v1 def testNoGradients(self): component = constant_op.constant([1.]) side = constant_op.constant(0.) add = lambda x: x + side dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) value = dataset_ops.make_one_shot_iterator(dataset).get_next() self.assertIsNone(gradients_impl.gradients(value, component)[0]) self.assertIsNone(gradients_impl.gradients(value, side)[0]) self.assertIsNone(gradients_impl.gradients(value, [component, side])[0]) @test_util.run_deprecated_v1 def testCapturingStateInOneShotRaisesException(self): var = variables.Variable(37.0, name="myvar") dataset = ( dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]) .map(lambda x: x + var)) with self.assertRaisesRegexp( ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support " "datasets that capture stateful objects.+myvar"): dataset_ops.make_one_shot_iterator(dataset) @test_util.run_deprecated_v1 def testOneShotIterator(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(14)) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @test_util.run_deprecated_v1 def testOneShotIteratorCaptureByValue(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) 
def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(tensor_components) .map(_map_fn).repeat(14)) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testOneShotIteratorInsideContainer(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) def within_container(): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensor_slices(components) .map(_map_fn).repeat(14)) return iterator.get_next() server = server_lib.Server.create_local_server() # Create two iterators within unique containers, and run them to # make sure that the resources aren't shared. # # The test below would fail if cname were the same across both # sessions. 
for j in range(2): with session.Session(server.target) as sess: cname = "iteration%d" % j with ops.container(cname): get_next = within_container() for _ in range(14): for i in range(7): result = sess.run(get_next) for component, result_component in zip(components, result): self.assertAllEqual(component[i]**2, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @test_util.run_deprecated_v1 def testOneShotIteratorNonBlocking(self): dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() # Create a session with a single thread to ensure that the # one-shot iterator initializer does not deadlock. config = config_pb2.ConfigProto( inter_op_parallelism_threads=1, use_per_session_threads=True) with session.Session(config=config) as sess: self.assertAllEqual([1, 4, 9], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) # Test with multiple threads invoking the one-shot iterator concurrently. with session.Session(config=config) as sess: results = [] def consumer_thread(): try: results.append(sess.run(next_element)) except errors.OutOfRangeError: results.append(None) num_threads = 8 threads = [ self.checkedThread(consumer_thread) for _ in range(num_threads) ] for t in threads: t.start() for t in threads: t.join() self.assertEqual(num_threads, len(results)) self.assertEqual(num_threads - 1, len([None for r in results if r is None])) self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) @test_util.run_deprecated_v1 def testOneShotIteratorInitializerFails(self): # Define a dataset whose initialization will always fail. 
dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) iterator = dataset_ops.make_one_shot_iterator(dataset) next_element = iterator.get_next() with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) # Test that subsequent attempts to use the iterator also fail. with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) with self.cached_session() as sess: def consumer_thread(): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) num_threads = 8 threads = [ self.checkedThread(consumer_thread) for _ in range(num_threads) ] for t in threads: t.start() for t in threads: t.join() def testSimpleSharedResource(self): components = (np.array(1, dtype=np.int64), np.array([1, 2, 3], dtype=np.int64), np.array(37.0, dtype=np.float64)) server = server_lib.Server.create_local_server() # Create two non-overlapping sessions that share the same iterator # resource on the same server, and verify that an action of the # first session (initializing the iterator) is visible in the # second session. with ops.Graph().as_default(): iterator = ( dataset_ops.Dataset.from_tensors(components) .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( shared_name="shared_iterator")) init_op = iterator.initializer get_next = iterator.get_next() with session.Session(server.target) as sess: sess.run(init_op) results = sess.run(get_next) for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Re-initialize the iterator in the first session. 
sess.run(init_op) with ops.Graph().as_default(): # Re-define the iterator manually, without defining any of the # functions in this graph, to ensure that we are not # accidentally redefining functions with the same names in the # new graph. iterator = iterator_ops.Iterator.from_structure( shared_name="shared_iterator", output_types=(dtypes.int64, dtypes.int64, dtypes.float64), output_shapes=([], [3], [])) get_next = iterator.get_next() with session.Session(server.target) as sess: # Use the iterator without re-initializing in the second session. results = sess.run(get_next) for component, result_component in zip(components, results): self.assertAllEqual(component, result_component) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @test_util.run_deprecated_v1 def testNotInitializedError(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensors(components)) get_next = iterator.get_next() with self.cached_session() as sess: with self.assertRaisesRegexp(errors.FailedPreconditionError, "iterator has not been initialized"): sess.run(get_next) @test_util.run_deprecated_v1 def testReinitializableIterator(self): dataset_3 = dataset_ops.Dataset.from_tensors( constant_op.constant([1, 2, 3])) dataset_4 = dataset_ops.Dataset.from_tensors( constant_op.constant([4, 5, 6, 7])) iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, [None]) dataset_3_init_op = iterator.make_initializer(dataset_3) dataset_4_init_op = iterator.make_initializer(dataset_4) get_next = iterator.get_next() self.assertEqual(dataset_3.output_types, iterator.output_types) self.assertEqual(dataset_4.output_types, iterator.output_types) self.assertEqual([None], iterator.output_shapes.as_list()) with self.cached_session() as sess: # The iterator is initially uninitialized. with self.assertRaises(errors.FailedPreconditionError): sess.run(get_next) # Initialize with one dataset. 
sess.run(dataset_3_init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Initialize with a different dataset. sess.run(dataset_4_init_op) self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Reinitialize with the first dataset. sess.run(dataset_3_init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @test_util.run_deprecated_v1 def testReinitializableIteratorWithFunctions(self): def g(): for i in range(10): yield i iterator = iterator_ops.Iterator.from_structure(dtypes.int64, []) next_element = iterator.get_next() with self.cached_session() as sess: dataset_1 = dataset_ops.Dataset.from_generator( g, output_types=dtypes.int64) sess.run(iterator.make_initializer(dataset_1)) for expected in range(10): self.assertEqual(expected, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) dataset_2 = dataset_ops.Dataset.from_generator( g, output_types=dtypes.int64) sess.run(iterator.make_initializer(dataset_2)) for expected in range(10): self.assertEqual(expected, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) def testReinitializableIteratorStaticErrors(self): # Non-matching structure for types and shapes. with self.assertRaises(TypeError): iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), [None]) # Test validation of dataset argument. iterator = iterator_ops.Iterator.from_structure((dtypes.int64, dtypes.float64)) # Incompatible structure. with self.assertRaises(ValueError): iterator.make_initializer( dataset_ops.Dataset.from_tensors(((constant_op.constant( [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant( [4., 5., 6., 7.], dtype=dtypes.float64),)))) # Incompatible types. 
with self.assertRaises(TypeError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( (constant_op.constant([1, 2, 3], dtype=dtypes.int32), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32)))) # Incompatible shapes. iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), ([None], [])) with self.assertRaises(TypeError): iterator.make_initializer( dataset_ops.Dataset.from_tensors( (constant_op.constant([1, 2, 3], dtype=dtypes.int64), constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64)))) @test_util.run_deprecated_v1 def testIteratorStringHandle(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) next_element = feedable_iterator.get_next() self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) self.assertEqual([], feedable_iterator.output_shapes) with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) self.assertEqual(10, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual(1, sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual(20, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual(2, sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual(30, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual(3, sess.run( next_element, 
feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual(40, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle}) @test_util.run_deprecated_v1 def testIteratorStringHandleFuture(self): with forward_compat.forward_compatibility_horizon(2018, 8, 4): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) next_element = feedable_iterator.get_next() self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) self.assertEqual([], feedable_iterator.output_shapes) with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) self.assertEqual( 10, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 1, sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 20, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 2, sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 30, sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle})) self.assertEqual( 3, sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle})) self.assertEqual( 40, sess.run( 
next_element, feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): sess.run( next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle}) @test_util.run_deprecated_v1 def testIteratorStringHandleReuseTensorObject(self): dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset) initializable_iterator = dataset_ops.make_initializable_iterator(dataset) structure_iterator = iterator_ops.Iterator.from_structure( dataset.output_types) created_ops = len(ops.get_default_graph().get_operations()) self.assertIs(one_shot_iterator.string_handle(), one_shot_iterator.string_handle()) self.assertIs(initializable_iterator.string_handle(), initializable_iterator.string_handle()) self.assertIs(structure_iterator.string_handle(), structure_iterator.string_handle()) # Assert that getting the (default) string handle creates no ops. self.assertEqual(created_ops, len(ops.get_default_graph().get_operations())) # Specifying an explicit name will create a new op. 
handle_with_name = one_shot_iterator.string_handle(name="foo") self.assertEqual("foo", handle_with_name.op.name) self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name) handle_with_same_name = one_shot_iterator.string_handle(name="foo") self.assertEqual("foo_1", handle_with_same_name.op.name) self.assertIsNot(handle_with_name, handle_with_same_name) @test_util.run_deprecated_v1 def testIteratorStringHandleError(self): dataset_int_scalar = ( dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat()) dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) feedable_int_scalar = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, []) feedable_int_vector = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, [None]) feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) with self.cached_session() as sess: handle_int_scalar = sess.run(dataset_ops.make_one_shot_iterator( dataset_int_scalar).string_handle()) handle_float_vector = sess.run(dataset_ops.make_one_shot_iterator( dataset_float_vector).string_handle()) self.assertEqual(1, sess.run( feedable_int_scalar.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) self.assertEqual(2, sess.run( feedable_int_any.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) with self.assertRaises(errors.InvalidArgumentError): print(sess.run( feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_int_scalar})) with self.assertRaises(errors.InvalidArgumentError): print(sess.run( feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_float_vector})) @test_util.run_deprecated_v1 def testRemoteIteratorUsingRemoteCallOpDirectSession(self): worker_config = config_pb2.ConfigProto() worker_config.device_count["CPU"] = 3 with ops.device("/job:localhost/replica:0/task:0/cpu:1"): dataset_3 = 
dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = iterator_3.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/cpu:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) remote_op = functional_ops.remote_call( args=[iterator_3_handle], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.session(config=worker_config) as sess: elem = sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [1]) # Fails when target is cpu:2 where the resource is not located. with self.assertRaises(errors.InvalidArgumentError): sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" }) elem = sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [2]) elem = sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) @test_util.run_deprecated_v1 def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self): s1 = server_lib.Server.create_local_server() s2 = server_lib.Server.create_local_server() s3 = server_lib.Server.create_local_server() cluster_def = cluster_pb2.ClusterDef() workers = cluster_def.job.add() workers.name = "worker" workers.tasks[0] = s1.target[len("grpc://"):] workers.tasks[1] = s2.target[len("grpc://"):] client = cluster_def.job.add() client.name = "client" client.tasks[0] = s3.target[len("grpc://"):] config = config_pb2.ConfigProto(cluster_def=cluster_def) worker_devices = [ 
"/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2) ] itr_handles = [] for device in worker_devices: with ops.device(device): src = dataset_ops.Dataset.from_tensor_slices([device]) itr = dataset_ops.make_one_shot_iterator(src) itr_handles.append(itr.string_handle()) targets = dataset_ops.Dataset.from_tensor_slices(worker_devices) handles = dataset_ops.Dataset.from_tensor_slices(itr_handles) @function.Defun(dtypes.string) def loading_func(h): remote_itr = iterator_ops.Iterator.from_string_handle( h, itr.output_types, itr.output_shapes) return remote_itr.get_next() def map_fn(target, handle): return functional_ops.remote_call( args=[handle], Tout=[dtypes.string], f=loading_func, target=target) with ops.device("/job:client"): client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn) itr = dataset_ops.make_initializable_iterator(client_dataset) n = itr.get_next() with session.Session(s3.target, config=config) as sess: sess.run(itr.initializer) expected_values = worker_devices for expected in expected_values: self.assertEqual((compat.as_bytes(expected),), sess.run(n)) with self.assertRaises(errors.OutOfRangeError): sess.run(n) def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") with ops.device("/job:localhost/replica:0/task:0/cpu:0"): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) iterator_3_handle = iterator_3.string_handle() def _encode_raw(byte_array): return bytes(bytearray(byte_array)) @function.Defun(dtypes.uint8) def _remote_fn(h): handle = script_ops.py_func(_encode_raw, [h], dtypes.string) remote_iterator = iterator_ops.Iterator.from_string_handle( handle, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 
iterator_3_handle_uint8 = parsing_ops.decode_raw( bytes=iterator_3_handle, out_type=dtypes.uint8) remote_op = functional_ops.remote_call( args=[iterator_3_handle_uint8], Tout=[dtypes.int32], f=_remote_fn, target=target_placeholder) with self.cached_session() as sess: elem = sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [1]) elem = sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [2]) elem = sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): sess.run( remote_op, feed_dict={ target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) @test_util.run_deprecated_v1 def testIncorrectIteratorRestore(self): def _path(): return os.path.join(self.get_temp_dir(), "iterator") def _save_op(iterator_resource): iterator_state_variant = gen_dataset_ops.serialize_iterator( iterator_resource) save_op = io_ops.write_file( _path(), parsing_ops.serialize_tensor(iterator_state_variant)) return save_op def _restore_op(iterator_resource): iterator_state_variant = parsing_ops.parse_tensor( io_ops.read_file(_path()), dtypes.variant) restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, iterator_state_variant) return restore_op def _build_range_dataset_graph(): start = 1 stop = 10 iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(start, stop)) init_op = iterator.initializer get_next = iterator.get_next() save_op = _save_op(iterator._iterator_resource) restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op def _build_reader_dataset_graph(): filenames = ["test"] # Does not exist but we don't care in this test. 
iterator = dataset_ops.make_initializable_iterator( readers.FixedLengthRecordDataset(filenames, 1, 0, 0)) init_op = iterator.initializer get_next_op = iterator.get_next() save_op = _save_op(iterator._iterator_resource) restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op # Saving iterator for RangeDataset graph. with ops.Graph().as_default() as g: init_op, _, save_op, _ = _build_range_dataset_graph() with self.session(graph=g) as sess: sess.run(init_op) sess.run(save_op) # Attempt to restore the saved iterator into an IteratorResource of # incompatible type. An iterator of RangeDataset has output type int64, # while an iterator of FixedLengthRecordDataset has output type string. # So an InvalidArgumentError should be raised by # IteratorResource::set_iterator. with ops.Graph().as_default() as g: _, _, _, restore_op = _build_reader_dataset_graph() with self.session(graph=g) as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(restore_op) @test_util.run_deprecated_v1 def testRepeatedGetNextWarning(self): iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10)) warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: for _ in range(100): iterator.get_next() self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w)) for warning in w: self.assertIn( iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message)) def testEagerIteratorAsync(self): with context.eager_mode(), context.execution_mode(context.ASYNC): val = 0 dataset = dataset_ops.Dataset.range(10) for foo in dataset: self.assertEqual(val, foo.numpy()) val += 1 # pylint: disable=g-long-lambda @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), structure.TensorStructure(dtypes.float32, []), ops.Tensor, dtypes.float32, []), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32), 
dense_shape=[1]), structure.SparseTensorStructure(dtypes.int32, [1]), sparse_tensor.SparseTensor, dtypes.int32, [1]), ("Nest", lambda: { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))}, structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.TensorStructure(dtypes.string, [1]), structure.TensorStructure(dtypes.string, []))}), {"a": ops.Tensor, "b": (ops.Tensor, ops.Tensor)}, {"a": dtypes.float32, "b": (dtypes.string, dtypes.string)}, {"a": [], "b": ([1], [])}), ) def testIteratorStructure(self, tf_value_fn, expected_element_structure, expected_output_classes, expected_output_types, expected_output_shapes): tf_value = tf_value_fn() iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(tf_value)) self.assertTrue(expected_element_structure.is_compatible_with( iterator._element_structure)) self.assertTrue(iterator._element_structure.is_compatible_with( expected_element_structure)) self.assertEqual(expected_output_classes, iterator.output_classes) self.assertEqual(expected_output_types, iterator.output_types) self.assertEqual(expected_output_shapes, iterator.output_shapes) def testIteratorGetNextName(self): with ops.Graph().as_default(): iterator = dataset_ops.make_one_shot_iterator( dataset_ops.Dataset.from_tensors(37.0)) next_element = iterator.get_next(name="overridden_name") self.assertEqual("overridden_name", next_element.op.name)
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for dataset serialization, `_inputs()` and element structure."""

  def testAsSerializedGraph(self):
    """The serialized graph of a range dataset has a `RangeDataset` node."""
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    # The comparison has to be `==`: with `!=` the assertion is satisfied by
    # any node of a different op type and never looks for the range dataset.
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):
    """Returns a function that caches its input through `Dataset.apply()`.

    Args:
      dataset: Unused by the returned function; kept for the callers.

    Returns:
      A function mapping a dataset to `dataset.apply(...)` of a `cache()`
      transformation.
    """

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():
    """Returns a generator function that yields the single value 42."""

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):
    """Returns a function that interleaves empty range datasets.

    Args:
      dataset: Unused by the returned function; kept for the callers.
      num_parallel_calls: Forwarded to `Dataset.interleave()`.

    Returns:
      A function mapping a dataset to its `interleave()` transformation with
      `cycle_length=2`.
    """

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      # This case previously built `from_tensors([42])`, duplicating the case
      # above and leaving `from_tensor_slices` without coverage.
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    """A source dataset reports the expected number of inputs."""
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    """A dataset built from sparse tensor slices reports no inputs."""
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10), lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    """A unary transformation reports exactly its input dataset."""
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    """A transformation built via `apply()` reports its input dataset."""
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                    input_dataset_fn):
    """An interleave transformation reports exactly its input dataset."""
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    """A binary transformation reports both inputs, in order."""
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    """A variadic transformation reports the flattened input datasets."""
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    """Walking `_inputs()` breadth-first visits shared inputs repeatedly."""
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue.pop(0)
      queue.extend(ds._inputs())
      inputs.append(ds)

    # ds1 is reached once directly from ds3 and twice through each ds2.
    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.TensorStructure(dtypes.string, [1]),
                structure.TensorStructure(dtypes.string, []))
      })),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    """The structure of a dataset matches its elements and round-trips."""
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(
        expected_element_structure.is_compatible_with(
            dataset_structure._element_structure))
    self.assertTrue(
        dataset_structure._element_structure.is_compatible_with(
            expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  # NOTE: This test is specific to graph mode and is skipped in eager mode.
  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    """Transforming a dataset from another graph raises `ValueError`."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  # NOTE: This test is specific to graph mode and is skipped in eager mode.
  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    """A one-shot iterator made in another graph logs a warning."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  # NOTE: This test is specific to graph mode and is skipped in eager mode.
  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    """Transforming a dataset from another graph raises `ValueError`."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `optional_ops.Optional` and its structure."""

  @staticmethod
  def _devices_under_test():
    """Returns "/cpu:0", followed by "/gpu:0" when a GPU is available."""
    device_names = ["/cpu:0"]
    if test_util.is_gpu_available():
      device_names.append("/gpu:0")
    return device_names

  def testFromValue(self):
    """An optional built from a tensor has that tensor as its value."""
    wrapped = constant_op.constant(37.0)
    optional = optional_ops.Optional.from_value(wrapped)
    self.assertTrue(self.evaluate(optional.has_value()))
    self.assertEqual(37.0, self.evaluate(optional.get_value()))

  def testFromStructuredValue(self):
    """An optional built from a nested value returns the same nest."""
    nested_value = {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    }
    optional = optional_ops.Optional.from_value(nested_value)
    self.assertTrue(self.evaluate(optional.has_value()))
    expected = {"a": 37.0, "b": ([b"Foo"], b"Bar")}
    self.assertEqual(expected, self.evaluate(optional.get_value()))

  def testFromSparseTensor(self):
    """An optional built from sparse tensor values preserves each of them."""
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    st_1 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))
    optional = optional_ops.Optional.from_value((st_0, st_1))
    self.assertTrue(self.evaluate(optional.has_value()))
    val_0, val_1 = optional.get_value()
    for want, got in zip((st_0, st_1), (val_0, val_1)):
      self.assertAllEqual(want.indices, self.evaluate(got.indices))
      self.assertAllEqual(want.values, self.evaluate(got.values))
      self.assertAllEqual(want.dense_shape, self.evaluate(got.dense_shape))

  def testFromNone(self):
    """An empty optional keeps its structure and fails on `get_value()`."""
    value_structure = structure.TensorStructure(dtypes.float32, [])
    optional = optional_ops.Optional.none_from_structure(value_structure)

    self.assertTrue(
        optional.value_structure.is_compatible_with(value_structure))

    different_shape = structure.TensorStructure(dtypes.float32, [1])
    self.assertFalse(
        optional.value_structure.is_compatible_with(different_shape))

    different_dtype = structure.TensorStructure(dtypes.int32, [])
    self.assertFalse(
        optional.value_structure.is_compatible_with(different_dtype))

    self.assertFalse(self.evaluate(optional.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(optional.get_value())

  def testAddN(self):
    """`add_n` sums optionals with values and keeps empty ones empty."""
    for device_name in self._devices_under_test():
      with ops.device(device_name):
        # With value
        lhs = optional_ops.Optional.from_value((1.0, 2.0))
        rhs = optional_ops.Optional.from_value((3.0, 4.0))

        summed_variant = math_ops.add_n(
            [lhs._variant_tensor, rhs._variant_tensor])
        summed = optional_ops._OptionalImpl(summed_variant,
                                            lhs.value_structure)
        self.assertAllEqual(self.evaluate(summed.get_value()), (4.0, 6.0))

        # Without value
        empty_lhs = optional_ops.Optional.none_from_structure(
            lhs.value_structure)
        empty_rhs = optional_ops.Optional.none_from_structure(
            rhs.value_structure)
        summed_variant = math_ops.add_n(
            [empty_lhs._variant_tensor, empty_rhs._variant_tensor])
        summed = optional_ops._OptionalImpl(summed_variant,
                                            empty_lhs.value_structure)
        self.assertFalse(self.evaluate(summed.has_value()))

  def testNestedAddN(self):
    """`add_n` also sums optionals nested inside other optionals."""
    for device_name in self._devices_under_test():
      with ops.device(device_name):
        inner_lhs = optional_ops.Optional.from_value([1, 2.0])
        inner_rhs = optional_ops.Optional.from_value([3, 4.0])
        outer_lhs = optional_ops.Optional.from_value(
            (5.0, inner_lhs._variant_tensor))
        outer_rhs = optional_ops.Optional.from_value(
            (6.0, inner_rhs._variant_tensor))

        summed_variant = math_ops.add_n(
            [outer_lhs._variant_tensor, outer_rhs._variant_tensor])
        summed = optional_ops._OptionalImpl(summed_variant,
                                            outer_lhs.value_structure)
        self.assertEqual(self.evaluate(summed.get_value()[0]), 11.0)

        inner_summed = optional_ops._OptionalImpl(
            summed.get_value()[1], inner_lhs.value_structure)
        self.assertAllEqual(inner_summed.get_value(), [4, 6.0])

  def testZerosLike(self):
    """`zeros_like` zeroes a present value and keeps an empty one empty."""
    for device_name in self._devices_under_test():
      with ops.device(device_name):
        # With value
        source = optional_ops.Optional.from_value((1.0, 2.0))
        zeroed_variant = array_ops.zeros_like(source._variant_tensor)
        zeroed = optional_ops._OptionalImpl(zeroed_variant,
                                            source.value_structure)
        self.assertAllEqual(self.evaluate(zeroed.get_value()), (0.0, 0.0))

        # Without value
        empty_source = optional_ops.Optional.none_from_structure(
            source.value_structure)
        zeroed_variant = array_ops.zeros_like(empty_source._variant_tensor)
        zeroed = optional_ops._OptionalImpl(zeroed_variant,
                                            empty_source.value_structure)
        self.assertFalse(self.evaluate(zeroed.has_value()))

  def testNestedZerosLike(self):
    """`zeros_like` also zeroes an optional nested in another optional."""
    for device_name in self._devices_under_test():
      with ops.device(device_name):
        inner = optional_ops.Optional.from_value(1.0)
        outer = optional_ops.Optional.from_value(inner._variant_tensor)

        zeroed_variant = array_ops.zeros_like(outer._variant_tensor)
        zeroed_outer = optional_ops._OptionalImpl(zeroed_variant,
                                                  outer.value_structure)
        zeroed_inner = optional_ops._OptionalImpl(zeroed_outer.get_value(),
                                                  inner.value_structure)
        self.assertEqual(self.evaluate(zeroed_inner.get_value()), 0.0)

  def testCopyToGPU(self):
    """Optionals created on the CPU can be read after an identity on GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      cpu_components = (constant_op.constant(37.0),
                        constant_op.constant("Foo"),
                        constant_op.constant(42))
      optional_with_value = optional_ops.Optional.from_value(cpu_components)
      optional_none = optional_ops.Optional.none_from_structure(
          structure.TensorStructure(dtypes.float32, []))

    with ops.device("/gpu:0"):
      gpu_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.value_structure)
      gpu_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.value_structure)

      with_value_has_value_t = gpu_with_value.has_value()
      with_value_values_t = gpu_with_value.get_value()
      none_has_value_t = gpu_none.has_value()

    self.assertTrue(self.evaluate(with_value_has_value_t))
    self.assertEqual((37.0, b"Foo", 42), self.evaluate(with_value_values_t))
    self.assertFalse(self.evaluate(none_has_value_t))

  def testNestedCopyToGPU(self):
    """A nested optional created on the CPU can be unpacked on the GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      cpu_components = (constant_op.constant(37.0),
                        constant_op.constant("Foo"),
                        constant_op.constant(42))
      optional_with_value = optional_ops.Optional.from_value(cpu_components)
      optional_none = optional_ops.Optional.none_from_structure(
          structure.TensorStructure(dtypes.float32, []))
      nested_optional = optional_ops.Optional.from_value(
          (optional_with_value._variant_tensor,
           optional_none._variant_tensor, 1.0))

    with ops.device("/gpu:0"):
      gpu_nested = optional_ops._OptionalImpl(
          array_ops.identity(nested_optional._variant_tensor),
          nested_optional.value_structure)

      nested_has_value_t = gpu_nested.has_value()
      nested_values = gpu_nested.get_value()

    self.assertTrue(self.evaluate(nested_has_value_t))

    inner_with_value = optional_ops._OptionalImpl(
        nested_values[0], optional_with_value.value_structure)
    inner_none = optional_ops._OptionalImpl(
        nested_values[1], optional_none.value_structure)

    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(inner_with_value.get_value()))
    self.assertFalse(self.evaluate(inner_none.has_value()))
    self.assertEqual(1.0, self.evaluate(nested_values[2]))

  def _assertElementValueEqual(self, expected, actual):
    """Asserts that `actual` equals `expected`, recursing into dicts.

    Args:
      expected: A dict, a `SparseTensorValue`, or any value accepted by
        `assertAllEqual`.
      actual: The value to compare against `expected`.
    """
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for key in expected.keys():
        self._assertElementValueEqual(expected[key], actual[key])
      return

    if isinstance(expected, sparse_tensor.SparseTensorValue):
      # Sparse values are compared component by component.
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
      return

    self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.TensorStructure(dtypes.string, [1]),
                structure.TensorStructure(dtypes.string, []))
      })),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testOptionalStructure(self, tf_value_fn, expected_value_structure):
    """The structure of an optional matches its value and round-trips."""
    tf_value = tf_value_fn()
    optional = optional_ops.Optional.from_value(tf_value)

    self.assertTrue(
        expected_value_structure.is_compatible_with(optional.value_structure))
    self.assertTrue(
        optional.value_structure.is_compatible_with(expected_value_structure))

    optional_structure = structure.Structure.from_value(optional)
    self.assertIsInstance(optional_structure, optional_ops.OptionalStructure)
    self.assertTrue(optional_structure.is_compatible_with(optional_structure))
    self.assertTrue(
        optional_structure._value_structure.is_compatible_with(
            expected_value_structure))
    self.assertEqual([dtypes.variant], optional_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()],
                     optional_structure._flat_shapes)

    # All OptionalStructure objects are not compatible with a non-optional
    # value.
    plain_structure = structure.Structure.from_value(
        constant_op.constant(42.0))
    self.assertFalse(optional_structure.is_compatible_with(plain_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    tensor_list = optional_structure._to_tensor_list(optional)
    round_trip_opt = optional_structure._from_tensor_list(tensor_list)
    if isinstance(tf_value, optional_ops.Optional):
      self.assertEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self.assertEqual(
          self.evaluate(tf_value),
          self.evaluate(round_trip_opt.get_value()))

  # NOTE: This test is specific to graph mode and is skipped in eager mode.
  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
       False),
      ("Nest", {
          "a": np.array([1, 2, 3], dtype=np.int32),
          "b": sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [1, 1]],
              values=np.array([-1., 1.], dtype=np.float32),
              dense_shape=[2, 2])
      }, lambda: {
          "a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
          "b": sparse_tensor.SparseTensor(
              indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
              dense_shape=[2, 2])
      }, False),
  )
  @test_util.run_deprecated_v1
  def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                             works_on_gpu):
    """`get_next_as_optional` yields each element, then empty optionals."""
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")

    repeated = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
    iterator = repeated.make_initializable_iterator()
    next_elem = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(next_elem, optional_ops.Optional)

    expected_structure = structure.Structure.from_value(tf_value_fn())
    self.assertTrue(
        next_elem.value_structure.is_compatible_with(expected_structure))

    elem_has_value_t = next_elem.has_value()
    elem_value_t = next_elem.get_value()
    with self.cached_session() as sess:
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError.
      for uninitialized_fetch in (elem_has_value_t, elem_value_t):
        with self.assertRaises(errors.FailedPreconditionError):
          sess.run(uninitialized_fetch)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      sess.run(iterator.initializer)
      for _ in range(3):
        fetched = sess.run([elem_has_value_t, elem_value_t])
        elem_has_value, elem_value = fetched
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(sess.run(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(elem_value_t)

  def testFunctionBoundaries(self):
    """An optional's variant tensor can be passed between functions."""

    @def_function.function
    def get_optional():
      one = constant_op.constant(1.0)
      wrapped = optional_ops.Optional.from_value(one)
      # TODO(skyewm): support returning Optionals from functions?
      return wrapped._variant_tensor

    # TODO(skyewm): support Optional arguments?
    @def_function.function
    def consume_optional(opt_tensor):
      scalar_float = structure.TensorStructure(dtypes.float32, [])
      rebuilt = optional_ops._OptionalImpl(opt_tensor, scalar_float)
      return rebuilt.get_value()

    variant = get_optional()
    result = consume_optional(variant)
    self.assertEqual(self.evaluate(result), 1.0)
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for dataset serialization, input tracking, structure and tracing."""

  def testAsSerializedGraph(self):
    """The serialized graph of a range dataset contains its dataset op."""
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    # Match the op by equality: an inequality test would be satisfied by any
    # unrelated node in the graph and therefore verify nothing.
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
    """A dataset revived from its traced variant yields the same elements."""
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    """Same as `testAsFunctionWithMap`, with a `map` nested in a `flat_map`."""
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):
    """Returns a function that applies `cache()` through `Dataset.apply`.

    Args:
      dataset: Unused; kept for interface compatibility.
    """

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():
    """Returns a generator function that yields the single value 42."""

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):
    """Returns a function that interleaves empty range datasets.

    Args:
      dataset: Unused; kept for interface compatibility.
      num_parallel_calls: Forwarded to `Dataset.interleave`.
    """

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      # This case must build the dataset with `from_tensor_slices`; using
      # `from_tensors` here would merely duplicate the "FromTensors" case.
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    """Source datasets report the expected number of input datasets."""
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    """A dataset built from sparse tensor slices has no input datasets."""
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    """A unary transformation reports exactly its input dataset."""
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    """A transformation applied via `apply` reports its input dataset."""
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    """Interleave transformations report exactly their input dataset."""
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    """Building an interleave pipeline does not call `warnings.warn`."""
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    """A binary transformation reports both inputs, in order."""
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    """A variadic transformation reports the flattened input datasets."""
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    """A `map` dataset exposes exactly one function."""
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    """Breadth-first traversal of `_inputs()` visits every use of a dataset."""
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    # `ds1` is reached once directly from `ds3` and twice through each of the
    # two `ds2` inputs.
    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      },
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))
       })),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    """Checks the structure of a dataset and its tensor-list round trip.

    Args:
      tf_value_fn: Callable returning the value each dataset element holds.
      expected_element_structure: Structure the dataset's element structure
        must be mutually compatible with.
    """
    dataset = dataset_ops.Dataset.from_tensors(0).map(
        lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(
        expected_element_structure.is_compatible_with(
            dataset_structure._element_structure))
    self.assertTrue(
        dataset_structure._element_structure.is_compatible_with(
            expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      # Flatten the round-tripped dataset (not the original one) so that this
      # branch exercises the round trip like the other branches do.
      self.assertDatasetsEqual(
          value, round_trip_dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    """Transforming a dataset from another graph raises a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    """A one-shot iterator over a foreign-graph dataset logs a warning."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    """Transforming a dataset from another graph raises a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    """Iterating a range dataset eagerly yields 0, 1, 2, ... in order."""
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):
    """Datasets passed to a `tf.function` are arguments, not captures."""

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracingWithCompositeTensors(self):
    """Datasets with the same structure share a single function trace."""
    # A one-element list lets the traced function mutate the counter.
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
def __init__(self,
             input_dataset,
             functions,
             ratio_numerator=1,
             ratio_denominator=1,
             num_elements_per_branch=None):
  """Chooses the fastest of some dataset functions.

  Given dataset functions that take input_dataset as input and output
  another dataset, produces elements as quickly as the fastest of these
  output datasets. Note that datasets in the dataset functions are assumed
  to be stateless, and the iterators created by the functions' output
  datasets will, given the same input elements, all produce the same
  output elements. Datasets in the functions are also expected to iterate
  over the input dataset at most once. The violation of these conditions
  may lead to undefined behavior.

  For example:
  ```python
  dataset = tf.data.Dataset.range(100)
  dataset = _ChooseFastestBranchDataset(
      dataset,
      [
          lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
          lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
      ],
      ratio_numerator=10,
      num_elements_per_branch=10
  )
  ```
  The resulting dataset will produce elements equivalent to
  `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`,
  or
  `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

  Note that the first `num_elements_per_branch` iterations may be slower due
  to the overhead of dynamically picking the fastest dataset. Namely, for
  these iterations, the dataset will produce elements from any of branches
  to determine which input is the fastest. For all subsequent iterations,
  that input will be used.

  Args:
    input_dataset: A `Dataset` that can be used as input to `functions`.
    functions: A list of callables, each of which takes a `Dataset` as input
      and returns a `Dataset`.
    ratio_numerator: The numerator in the ratio of input elements consumed
      to output elements produced for each function. This should be the same
      for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_numerator should be 10.
    ratio_denominator: The denominator in the ratio of input elements
      consumed to output elements produced for each function. This should
      be the same for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_denominator should be 1.
    num_elements_per_branch: The number of elements to get from each branch
      before deciding which dataset is fastest. In the first
      len(functions) * num_elements_per_branch iterations, the dataset will
      call from one of the branches, and update its knowledge of which input
      is the fastest. Note that (num_elements_per_branch * ratio) is
      expected to be an integer. Defaults to `10 * ratio_denominator`.

  Returns:
    A `Dataset` that has the same elements as the inputs.

  Raises:
    ValueError: If `ratio_numerator` or `ratio_denominator` is not positive.
  """
  # Validate the cheap scalar arguments first so that invalid input is
  # rejected before any of the branch functions is traced.
  if ratio_numerator <= 0 or ratio_denominator <= 0:
    raise ValueError("ratio must be positive.")

  if num_elements_per_branch is None:
    # Pick a sensible default based on `ratio_denominator`
    num_elements_per_branch = 10 * ratio_denominator

  # Each branch function receives the input dataset itself as its argument.
  nested_structure = structure_lib.NestedStructure(
      dataset_ops.DatasetStructure(
          structure_lib.convert_legacy_structure(
              input_dataset.output_types, input_dataset.output_shapes,
              input_dataset.output_classes)))
  self._funcs = [
      dataset_ops.StructuredFunctionWrapper(
          f, "ChooseFastestV2", input_structure=nested_structure)
      for f in functions
  ]
  # The element structure is taken from the first branch only.
  self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

  # The captured inputs of all branches are passed to the op as one flat
  # list, together with the per-branch lengths.
  self._captured_arguments = []
  for f in self._funcs:
    self._captured_arguments.extend(f.function.captured_inputs)
  self._capture_lengths = [
      len(f.function.captured_inputs) for f in self._funcs
  ]

  variant_tensor = (
      gen_experimental_dataset_ops.choose_fastest_branch_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          ratio_numerator=ratio_numerator,
          ratio_denominator=ratio_denominator,
          other_arguments=self._captured_arguments,
          num_elements_per_branch=num_elements_per_branch,
          branches=[f.function for f in self._funcs],
          other_arguments_lengths=self._capture_lengths,
          **dataset_ops.flat_structure(self)))
  super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                    variant_tensor)
def __init__(self, input_dataset, initial_state, scan_func):
  """See `scan()` for details.

  Traces `scan_func` repeatedly, weakening the state shapes after each trace,
  until the state shapes no longer change, and then creates the scan dataset
  op from the final trace.

  Args:
    input_dataset: The dataset whose elements are fed to `scan_func`.
    initial_state: The initial state; normalized with
      `structure.normalize_tensors`.
    scan_func: Callable taking `(state, element)` and returning a pair
      `(new_state, output)`.

  Raises:
    TypeError: If `scan_func` does not return a pair, or if the classes or
      types of the new state do not match those of the initial state.
  """
  # `collections.Sequence` was removed in Python 3.10; the ABC lives in
  # `collections.abc`. Fall back to `collections` on Python 2, which has no
  # `collections.abc` submodule.
  # pylint: disable=g-import-not-at-top
  try:
    from collections import abc as collections_abc
  except ImportError:
    import collections as collections_abc
  # pylint: enable=g-import-not-at-top

  self._input_dataset = input_dataset
  self._initial_state = structure.normalize_tensors(initial_state)

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_structure = type_spec.type_spec_from_value(self._initial_state)

  # Iteratively rerun the scan function until reaching a fixed point on
  # the shapes of `self._state_structure`.
  need_to_rerun = True
  while need_to_rerun:

    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        self._transformation_name(),
        input_structure=structure.NestedStructure(
            (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
        add_to_graph=False)
    if not (
        isinstance(wrapped_func.output_types, collections_abc.Sequence) and
        len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")

    # Extract and validate class information from the returned values.
    new_state_classes, output_classes = wrapped_func.output_classes
    # The output classes are also kept on the instance.
    self._output_classes = output_classes
    old_state_classes = self._state_structure._to_legacy_output_classes()  # pylint: disable=protected-access
    for new_state_class, old_state_class in zip(
        nest.flatten(new_state_classes),
        nest.flatten(old_state_classes)):
      if not issubclass(new_state_class, old_state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, output_types = wrapped_func.output_types
    old_state_types = self._state_structure._to_legacy_output_types()  # pylint: disable=protected-access
    for new_state_type, old_state_type in zip(
        nest.flatten(new_state_types), nest.flatten(old_state_types)):
      if new_state_type != old_state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, output_shapes = wrapped_func.output_shapes
    old_state_shapes = self._state_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    flat_state_shapes = nest.flatten(old_state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Another trace is needed as soon as one state shape had to be weakened,
    # i.e. its rank became unknown or one of its dimensions changed.
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # TODO(b/110122868): Support a "most specific compatible structure"
      # method for combining structures, to avoid using legacy structures
      # in this method.
      self._state_structure = structure.convert_legacy_structure(
          old_state_types,
          nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
          old_state_classes)

  # Only the final trace is added to the graph (the wrappers above were
  # created with `add_to_graph=False`).
  self._scan_func = wrapped_func
  self._scan_func.function.add_to_graph(ops.get_default_graph())
  # pylint: disable=protected-access
  variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
      self._input_dataset._variant_tensor,
      self._state_structure._to_tensor_list(self._initial_state),
      self._scan_func.function.captured_inputs,
      f=self._scan_func.function,
      preserve_cardinality=True,
      **dataset_ops.flat_structure(self))
  super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    ragged_test_util.RaggedTensorTestCase):

  # pylint: disable=g-long-lambda,protected-access
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure, [dtypes.float32], [[]]),
      # One expected shape per flat component. A second entry here would be
      # dropped by `zip` and never compared against anything.
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       structure.TensorArrayStructure, [dtypes.variant], [None]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       structure.SparseTensorStructure, [dtypes.variant], [None]),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
       structure.RaggedTensorStructure, [dtypes.variant], [None]),
      ("Nested_0",
       lambda: (constant_op.constant(37.0),
                constant_op.constant([1, 2, 3])),
       structure.NestedStructure, [dtypes.float32, dtypes.int32],
       [[], [3]]),
      ("Nested_1", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, structure.NestedStructure, [dtypes.float32, dtypes.int32],
       [[], [3]]),
      ("Nested_2", lambda: {
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, structure.NestedStructure,
       [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]),
  )
  def testFlatStructure(self, value_fn, expected_structure, expected_types,
                        expected_shapes):
    """Checks the structure class, flat types and flat shapes of a value.

    Args:
      value_fn: Callable returning the value to build a structure from.
      expected_structure: Expected `Structure` subclass.
      expected_types: Expected list of flat component dtypes.
      expected_shapes: Expected flat component shapes, one per component;
        each must be mutually compatible with the actual shape.
    """
    value = value_fn()
    s = structure.Structure.from_value(value)
    self.assertIsInstance(s, expected_structure)
    self.assertEqual(expected_types, s._flat_types)
    flat_shapes = s._flat_shapes
    # `zip` stops at the shorter sequence, so compare the lengths explicitly;
    # otherwise surplus expected (or actual) shapes would go unchecked.
    self.assertEqual(len(expected_shapes), len(flat_shapes))
    for expected, actual in zip(expected_shapes, flat_shapes):
      self.assertTrue(actual.is_compatible_with(expected))
      self.assertTrue(
          tensor_shape.as_shape(expected).is_compatible_with(actual))

  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       lambda: [
           constant_op.constant(38.0),
           array_ops.placeholder(dtypes.float32),
           variables.Variable(100.0), 42.0,
           np.array(42.0, dtype=np.float32)
       ],
       lambda: [constant_op.constant([1.0, 2.0]),
                constant_op.constant(37)]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: [
           tensor_array_ops.TensorArray(
               dtype=dtypes.float32, element_shape=(3,), size=0),
           tensor_array_ops.TensorArray(
               dtype=dtypes.float32, element_shape=(3,), size=10)
       ],
       lambda: [
           tensor_array_ops.TensorArray(
               dtype=dtypes.int32, element_shape=(3,), size=0),
           tensor_array_ops.TensorArray(
               dtype=dtypes.float32, element_shape=(), size=0)
       ]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: [
           sparse_tensor.SparseTensor(
               indices=[[1, 1], [3, 4]], values=[10, -1],
               dense_shape=[4, 5]),
           sparse_tensor.SparseTensorValue(
               indices=[[1, 1], [3, 4]], values=[10, -1],
               dense_shape=[4, 5]),
           array_ops.sparse_placeholder(dtype=dtypes.int32),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None])
       ],
       lambda: [
           constant_op.constant(37, shape=[4, 5]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None, None]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
       ]),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
       lambda: [
           ragged_factory_ops.constant([[1, 2], [3, 4], []]),
           ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
       ],
       lambda: [
           ragged_factory_ops.constant(1),
           ragged_factory_ops.constant([1, 2]),
           ragged_factory_ops.constant([[1], [2]]),
           ragged_factory_ops.constant([["a", "b"]]),
       ]),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      },
       lambda: [{
           "a": constant_op.constant(15.0),
           "b": constant_op.constant([4, 5, 6])
       }],
       lambda: [{
           "a": constant_op.constant(15.0),
           "b": constant_op.constant([4, 5, 6, 7])
       }, {
           "a": constant_op.constant(15),
           "b": constant_op.constant([4, 5, 6])
       }, {
           "a": constant_op.constant(15),
           "b": sparse_tensor.SparseTensor(
               indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
       }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  @test_util.run_deprecated_v1
  def testIsCompatibleWithStructure(
      self, original_value_fn, compatible_values_fn, incompatible_values_fn):
    """Checks `is_compatible_with` against compatible/incompatible values.

    Args:
      original_value_fn: Callable returning the value whose structure is
        tested.
      compatible_values_fn: Callable returning values whose structures must
        be compatible with the original structure.
      incompatible_values_fn: Callable returning values whose structures
        must not be compatible with the original structure.
    """
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()
    s = structure.Structure.from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          s.is_compatible_with(
              structure.Structure.from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          s.is_compatible_with(
              structure.Structure.from_value(incompatible_value)))

  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0),
       lambda: constant_op.constant([5])),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
       lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]],
                                           ragged_rank=1),
       lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
      ("Nested",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])},
       lambda: {
           "a": constant_op.constant(42.0),
           "b": constant_op.constant([4, 5, 6])},
       lambda: {
           "a": constant_op.constant([1, 2, 3]),
           "b": constant_op.constant(37.0)
       }),
  )  # pyformat: disable
  def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                     *not_equal_value_fns):
    """Checks `==`, `!=` and `hash` on structures built from values.

    Args:
      value1_fn: Callable returning the first value.
      value2_fn: Callable returning a value whose structure must equal the
        structure of the first value.
      *not_equal_value_fns: Callables returning values whose structures must
        differ from both of the above.
    """
    # pylint: disable=g-generic-assert
    s1 = structure.Structure.from_value(value1_fn())
    s2 = structure.Structure.from_value(value2_fn())
    self.assertEqual(s1, s1)  # check __eq__ operator.
    self.assertEqual(s1, s2)  # check __eq__ operator.
    self.assertFalse(s1 != s1)  # check __ne__ operator.
    self.assertFalse(s1 != s2)  # check __ne__ operator.
    self.assertEqual(hash(s1), hash(s1))
    self.assertEqual(hash(s1), hash(s2))
    for value_fn in not_equal_value_fns:
      s3 = structure.Structure.from_value(value_fn())
      self.assertNotEqual(s1, s3)  # check __ne__ operator.
      self.assertNotEqual(s2, s3)  # check __ne__ operator.
      self.assertFalse(s1 == s3)  # check __eq__ operator.
      self.assertFalse(s2 == s3)  # check __eq__ operator.

  @parameterized.named_parameters(
      ("RaggedTensor_RaggedRank",
       structure.RaggedTensorStructure(dtypes.int32, None, 1),
       structure.RaggedTensorStructure(dtypes.int32, None, 2)),
      ("RaggedTensor_Shape",
       structure.RaggedTensorStructure(dtypes.int32, [3, None], 1),
       structure.RaggedTensorStructure(dtypes.int32, [5, None], 1)),
      ("RaggedTensor_DType",
       structure.RaggedTensorStructure(dtypes.int32, None, 1),
       structure.RaggedTensorStructure(dtypes.float32, None, 1)),
  )
  def testInequality(self, s1, s2):
    """Checks that two differing structures compare unequal."""
    # pylint: disable=g-generic-assert
    self.assertNotEqual(s1, s2)  # check __ne__ operator.
    self.assertFalse(s1 == s2)  # check __eq__ operator.
@parameterized.named_parameters(
    ("Tensor", lambda: constant_op.constant(37.0),
     lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])),
    ("TensorArray", lambda: tensor_array_ops.TensorArray(
        dtype=dtypes.float32, element_shape=(3,), size=0),
     lambda: tensor_array_ops.TensorArray(
         dtype=dtypes.float32, element_shape=(3,), size=0),
     lambda: tensor_array_ops.TensorArray(
         dtype=dtypes.int32, element_shape=(), size=0)),
    ("SparseTensor", lambda: sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
     lambda: sparse_tensor.SparseTensor(
         indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
     lambda: sparse_tensor.SparseTensor(indices=[[3]], values=[-1],
                                        dense_shape=[5])),
    ("Nested", lambda: {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }, lambda: {
        "a": constant_op.constant(42.0),
        "b": constant_op.constant([4, 5, 6])
    }, lambda: {
        "a": constant_op.constant([1, 2, 3]),
        "b": constant_op.constant(37.0)
    }),
)
def testHash(self, value1_fn, value2_fn, value3_fn):
  """Checks hash agreement for the first two values and not the third.

  Args:
    value1_fn: Zero-argument callable producing the reference value.
    value2_fn: Zero-argument callable producing a value whose structure is
      expected to hash the same as the reference.
    value3_fn: Zero-argument callable producing a value whose structure is
      expected to hash differently.
  """
  s1 = structure.Structure.from_value(value1_fn())
  s2 = structure.Structure.from_value(value2_fn())
  s3 = structure.Structure.from_value(value3_fn())
  self.assertEqual(hash(s1), hash(s1))
  self.assertEqual(hash(s1), hash(s2))
  self.assertNotEqual(hash(s1), hash(s3))
  self.assertNotEqual(hash(s2), hash(s3))

@parameterized.named_parameters(
    (
        "Tensor",
        lambda: constant_op.constant(37.0),
    ),
    (
        "SparseTensor",
        lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
    ),
    ("TensorArray", lambda: tensor_array_ops.TensorArray(
        dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
    ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),),
    (
        "Nested_0",
        lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        },
    ),
    (
        "Nested_1",
        lambda: {
            "a": constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        },
    ),
)
def testRoundTripConversion(self, value_fn):
  """Checks that `_to_tensor_list` then `_from_tensor_list` preserves a value.

  Args:
    value_fn: Zero-argument callable producing the value to round-trip.
  """
  value = value_fn()
  s = structure.Structure.from_value(value)

  def maybe_stack_ta(v):
    # A TensorArray is stacked into a tensor before it is evaluated; every
    # other value is passed through unchanged.
    if isinstance(v, tensor_array_ops.TensorArray):
      return v.stack()
    else:
      return v

  before = self.evaluate(maybe_stack_ta(value))
  after = self.evaluate(
      maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))

  flat_before = nest.flatten(before)
  flat_after = nest.flatten(after)
  for b, a in zip(flat_before, flat_after):
    if isinstance(b, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(b.indices, a.indices)
      self.assertAllEqual(b.values, a.values)
      self.assertAllEqual(b.dense_shape, a.dense_shape)
    elif isinstance(
        b,
        (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
      self.assertRaggedEqual(b, a)
    else:
      self.assertAllEqual(b, a)

# pylint: enable=g-long-lambda

def testIncompatibleStructure(self):
  """Checks errors when mixing tensor, sparse tensor and nested structures."""
  # Define three mutually incompatible values/structures, and assert that:
  # 1. Using one structure to flatten a value with an incompatible structure
  #    fails.
  # 2. Using one structure to restructure a flattened value with an
  #    incompatible structure fails.
  value_tensor = constant_op.constant(42.0)
  s_tensor = structure.Structure.from_value(value_tensor)
  flat_tensor = s_tensor._to_tensor_list(value_tensor)

  value_sparse_tensor = sparse_tensor.SparseTensor(
      indices=[[0, 0]], values=[1], dense_shape=[1, 1])
  s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
  flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)

  value_nest = {
      "a": constant_op.constant(37.0),
      "b": constant_op.constant([1, 2, 3])
  }
  s_nest = structure.Structure.from_value(value_nest)
  flat_nest = s_nest._to_tensor_list(value_nest)

  with self.assertRaisesRegexp(
      ValueError, r"SparseTensor.* is not convertible to a tensor with "
      r"dtype.*float32.* and shape \(\)"):
    s_tensor._to_tensor_list(value_sparse_tensor)
  with self.assertRaisesRegexp(
      ValueError, r"Value \{.*\} is not convertible to a tensor with "
      r"dtype.*float32.* and shape \(\)"):
    s_tensor._to_tensor_list(value_nest)

  with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
    s_sparse_tensor._to_tensor_list(value_tensor)

  with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
    s_sparse_tensor._to_tensor_list(value_nest)

  with self.assertRaisesRegexp(
      ValueError, "Tensor.* not compatible with the nested structure "
      ".*TensorStructure.*TensorStructure"):
    s_nest._to_tensor_list(value_tensor)

  with self.assertRaisesRegexp(
      ValueError, "SparseTensor.* not compatible with the nested structure "
      ".*TensorStructure.*TensorStructure"):
    s_nest._to_tensor_list(value_sparse_tensor)

  with self.assertRaisesRegexp(
      ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
    s_tensor._from_tensor_list(flat_sparse_tensor)

  with self.assertRaisesRegexp(
      ValueError, "TensorStructure corresponds to a single tf.Tensor."):
    s_tensor._from_tensor_list(flat_nest)

  with self.assertRaisesRegexp(
      ValueError, "SparseTensorStructure corresponds to a single tf.variant "
      "vector of length 3."):
    s_sparse_tensor._from_tensor_list(flat_tensor)

  with self.assertRaisesRegexp(
      ValueError, "SparseTensorStructure corresponds to a single tf.variant "
      "vector of length 3."):
    s_sparse_tensor._from_tensor_list(flat_nest)

  with self.assertRaisesRegexp(
      ValueError, "Expected 2 flat values in NestedStructure but got 1."):
    s_nest._from_tensor_list(flat_tensor)

  with self.assertRaisesRegexp(
      ValueError, "Expected 2 flat values in NestedStructure but got 1."):
    s_nest._from_tensor_list(flat_sparse_tensor)

def testIncompatibleNestedStructure(self):
  """Checks errors when mixing three incompatible nested structures."""
  # Define three mutually incompatible nested values/structures, and assert
  # that:
  # 1. Using one structure to flatten a value with an incompatible structure
  #    fails.
  # 2. Using one structure to restructure a flattened value with an
  #    incompatible structure fails.

  value_0 = {
      "a": constant_op.constant(37.0),
      "b": constant_op.constant([1, 2, 3])
  }
  s_0 = structure.Structure.from_value(value_0)
  flat_s_0 = s_0._to_tensor_list(value_0)

  # `value_1` has compatible nested structure with `value_0`, but different
  # classes.
  value_1 = {
      "a": constant_op.constant(37.0),
      "b": sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])
  }
  s_1 = structure.Structure.from_value(value_1)
  flat_s_1 = s_1._to_tensor_list(value_1)

  # `value_2` has incompatible nested structure with `value_0` and `value_1`.
  value_2 = {
      "a": constant_op.constant(37.0),
      "b": (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
            sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
  }
  s_2 = structure.Structure.from_value(value_2)
  flat_s_2 = s_2._to_tensor_list(value_2)

  with self.assertRaisesRegexp(
      ValueError, "SparseTensor.* not compatible with the nested structure "
      ".*TensorStructure"):
    s_0._to_tensor_list(value_1)

  with self.assertRaisesRegexp(
      ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
      "nested structure .*TensorStructure"):
    s_0._to_tensor_list(value_2)

  with self.assertRaisesRegexp(
      ValueError, "Tensor.* not compatible with the nested structure "
      ".*SparseTensorStructure"):
    s_1._to_tensor_list(value_0)

  # This pair was previously a verbatim repeat of the `s_0`/`value_2` check
  # above, which left flattening `value_2` with `s_1` uncovered.
  with self.assertRaisesRegexp(
      ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
      "nested structure .*SparseTensorStructure"):
    s_1._to_tensor_list(value_2)

  # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
  # needs to account for "a" coming before or after "b". It might be worth
  # adding a deterministic repr for these error messages (among other
  # improvements).
  with self.assertRaisesRegexp(
      ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
      ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
      "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
    s_2._to_tensor_list(value_0)

  with self.assertRaisesRegexp(
      ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
      "not compatible with the nested structure .*"
      "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
      "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
    s_2._to_tensor_list(value_1)

  with self.assertRaisesRegexp(
      ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
    s_0._from_tensor_list(flat_s_1)

  with self.assertRaisesRegexp(
      ValueError, "Expected 2 flat values in NestedStructure but got 3."):
    s_0._from_tensor_list(flat_s_2)

  with self.assertRaisesRegexp(
      ValueError, "SparseTensorStructure corresponds to a single tf.variant "
      "vector of length 3."):
    s_1._from_tensor_list(flat_s_0)

  with self.assertRaisesRegexp(
      ValueError, "Expected 2 flat values in NestedStructure but got 3."):
    s_1._from_tensor_list(flat_s_2)

  with self.assertRaisesRegexp(
      ValueError, "Expected 3 flat values in NestedStructure but got 2."):
    s_2._from_tensor_list(flat_s_0)

  with self.assertRaisesRegexp(
      ValueError, "Expected 3 flat values in NestedStructure but got 2."):
    s_2._from_tensor_list(flat_s_1)

@parameterized.named_parameters(
    ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
     structure.TensorStructure(dtypes.float32, [])),
    ("SparseTensor", dtypes.int32, tensor_shape.matrix(
        2, 2), sparse_tensor.SparseTensor,
     structure.SparseTensorStructure(dtypes.int32, [2, 2])),
    ("TensorArray_0", dtypes.int32, tensor_shape.as_shape(
        [None, True, 2, 2]), tensor_array_ops.TensorArray,
     structure.TensorArrayStructure(
         dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
    ("TensorArray_1", dtypes.int32, tensor_shape.as_shape(
        [True, None, 2, 2]), tensor_array_ops.TensorArray,
     structure.TensorArrayStructure(
         dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
    ("TensorArray_2", dtypes.int32, tensor_shape.as_shape(
        [True, False, 2, 2]), tensor_array_ops.TensorArray,
     structure.TensorArrayStructure(
         dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
    ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, 2),
     structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1),
     structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1)),
    ("Nested", {
        "a": dtypes.float32,
        "b": (dtypes.int32, dtypes.string)
    }, {
        "a": tensor_shape.scalar(),
        "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
    }, {
        "a": ops.Tensor,
        "b": (sparse_tensor.SparseTensor, ops.Tensor)
    },
     structure.NestedStructure({
         "a": structure.TensorStructure(dtypes.float32, []),
         "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
               structure.TensorStructure(dtypes.string, []))
     })),
)
def testConvertLegacyStructure(self, output_types, output_shapes,
                               output_classes, expected_structure):
  """Checks `convert_legacy_structure` against an expected structure.

  Args:
    output_types: Legacy types argument passed to the converter.
    output_shapes: Legacy shapes argument passed to the converter.
    output_classes: Legacy classes argument passed to the converter.
    expected_structure: Structure that the result must be mutually
      compatible with.
  """
  actual_structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  self.assertTrue(expected_structure.is_compatible_with(actual_structure))
  self.assertTrue(actual_structure.is_compatible_with(expected_structure))

def testNestedNestedStructure(self):
  """Checks flattening and rebuilding through a manually nested structure."""
  # Although `Structure.from_value()` will not construct one, a nested
  # structure containing nested `NestedStructure` objects can occur if a
  # structure is constructed manually.
  s = structure.NestedStructure(
      (structure.TensorStructure(dtypes.int64, []),
       structure.NestedStructure(
           (structure.TensorStructure(dtypes.float32, []),
            structure.TensorStructure(dtypes.string, [])))))

  int64_t = constant_op.constant(37, dtype=dtypes.int64)
  float32_t = constant_op.constant(42.0)
  string_t = constant_op.constant("Foo")

  nested_tensors = (int64_t, (float32_t, string_t))

  tensor_list = s._to_tensor_list(nested_tensors)
  for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
    self.assertIs(expected, actual)

  (actual_int64_t, (actual_float32_t, actual_string_t)) = s._from_tensor_list(
      tensor_list)
  self.assertIs(int64_t, actual_int64_t)
  self.assertIs(float32_t, actual_float32_t)
  self.assertIs(string_t, actual_string_t)

  (actual_int64_t, (actual_float32_t, actual_string_t)) = (
      s._from_compatible_tensor_list(tensor_list))
  self.assertIs(int64_t, actual_int64_t)
  self.assertIs(float32_t, actual_float32_t)
  self.assertIs(string_t, actual_string_t)

@parameterized.named_parameters(
    ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
     structure.TensorStructure(dtypes.float32, [32])),
    ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
     structure.TensorStructure(dtypes.float32, [None])),
    ("SparseTensor", structure.SparseTensorStructure(dtypes.float32, [None]),
     32, structure.SparseTensorStructure(dtypes.float32, [32, None])),
    ("SparseTensorUnknown",
     structure.SparseTensorStructure(dtypes.float32, [4]), None,
     structure.SparseTensorStructure(dtypes.float32, [None, 4])),
    ("RaggedTensor",
     structure.RaggedTensorStructure(dtypes.float32, [2, None], 1), 32,
     structure.RaggedTensorStructure(dtypes.float32, [32, 2, None], 2)),
    ("RaggedTensorUnknown",
     structure.RaggedTensorStructure(dtypes.float32, [4, None], 1), None,
     structure.RaggedTensorStructure(dtypes.float32, [None, 4, None], 2)),
    ("Nested", structure.NestedStructure({
        "a": structure.TensorStructure(dtypes.float32, []),
        "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
              structure.TensorStructure(dtypes.string, []))}), 128,
     structure.NestedStructure({
         "a": structure.TensorStructure(dtypes.float32, [128]),
         "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
               structure.TensorStructure(dtypes.string, [128]))})),
)
def testBatch(self, element_structure, batch_size,
              expected_batched_structure):
  """Checks `_batch` output is mutually compatible with the expectation.

  Args:
    element_structure: Structure whose `_batch` method is called.
    batch_size: Batch size passed to `_batch` (an int or `None`).
    expected_batched_structure: Structure the result is compared against.
  """
  batched_structure = element_structure._batch(batch_size)
  self.assertTrue(
      batched_structure.is_compatible_with(expected_batched_structure))
  self.assertTrue(
      expected_batched_structure.is_compatible_with(batched_structure))

@parameterized.named_parameters(
    ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
     structure.TensorStructure(dtypes.float32, [])),
    ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
     structure.TensorStructure(dtypes.float32, [])),
    ("SparseTensor",
     structure.SparseTensorStructure(dtypes.float32, [32, None]),
     structure.SparseTensorStructure(dtypes.float32, [None])),
    ("SparseTensorUnknown",
     structure.SparseTensorStructure(dtypes.float32, [None, 4]),
     structure.SparseTensorStructure(dtypes.float32, [4])),
    ("RaggedTensor",
     structure.RaggedTensorStructure(dtypes.float32, [32, 4, None], 2),
     structure.RaggedTensorStructure(dtypes.float32, [4, None], 1)),
    ("RaggedTensorUnknown",
     structure.RaggedTensorStructure(dtypes.float32, [None, None, 4], 2),
     structure.RaggedTensorStructure(dtypes.float32, [None, 4], 1)),
    ("Nested", structure.NestedStructure({
        "a": structure.TensorStructure(dtypes.float32, [128]),
        "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
              structure.TensorStructure(dtypes.string, [None]))}),
     structure.NestedStructure({
         "a": structure.TensorStructure(dtypes.float32, []),
         "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
               structure.TensorStructure(dtypes.string, []))})),
)
def testUnbatch(self, element_structure, expected_unbatched_structure):
  """Checks `_unbatch` output is mutually compatible with the expectation.

  Args:
    element_structure: Structure whose `_unbatch` method is called.
    expected_unbatched_structure: Structure the result is compared against.
  """
  unbatched_structure = element_structure._unbatch()
  self.assertTrue(
      unbatched_structure.is_compatible_with(expected_unbatched_structure))
  self.assertTrue(
      expected_unbatched_structure.is_compatible_with(unbatched_structure))

# pylint: disable=g-long-lambda
@parameterized.named_parameters(
    ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
     lambda: constant_op.constant([1.0, 2.0])),
    ("SparseTensor", lambda: sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
     lambda: sparse_tensor.SparseTensor(
         indices=[[0]], values=[13], dense_shape=[2])),
    ("RaggedTensor",
     lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
     lambda: ragged_factory_ops.constant([[1]])),
    ("Nest", lambda: (
        constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
     lambda: (constant_op.constant([1.0, 2.0]), sparse_tensor.SparseTensor(
         indices=[[0]], values=[13], dense_shape=[2]))),
)
def testToBatchedTensorList(self, value_fn, element_0_fn):
  """Checks `_to_batched_tensor_list` and element 0 of the unbatched result.

  Args:
    value_fn: Zero-argument callable producing the batched value.
    element_0_fn: Zero-argument callable producing the expected 0th element.
  """
  batched_value = value_fn()
  s = structure.Structure.from_value(batched_value)
  batched_tensor_list = s._to_batched_tensor_list(batched_value)

  # The batch dimension is 2 for all of the test cases.
  # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
  # tensors in which we store sparse tensors.
  for t in batched_tensor_list:
    if t.dtype != dtypes.variant:
      self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

  # Test that the 0th element from the unbatched tensor is equal to the
  # expected value.
  expected_element_0 = self.evaluate(element_0_fn())
  unbatched_s = s._unbatch()
  actual_element_0 = unbatched_s._from_tensor_list(
      [t[0] for t in batched_tensor_list])

  for expected, actual in zip(
      nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
    if sparse_tensor.is_sparse(expected):
      self.assertSparseValuesEqual(expected, actual)
    elif ragged_tensor.is_ragged(expected):
      self.assertRaggedEqual(expected, actual)
    else:
      self.assertAllEqual(expected, actual)
def __init__(self,
             filenames,
             record_defaults,
             compression_type=None,
             buffer_size=None,
             header=False,
             field_delim=",",
             use_quote_delim=True,
             na_value="",
             select_cols=None):
  """Creates a `CsvDataset` by reading and decoding CSV files.

  Each element of the dataset corresponds to one record of the input
  file(s). Files are expected to follow RFC 4180
  (https://tools.ietf.org/html/rfc4180). Leading and trailing spaces are
  allowed around int and float fields.

  For example, suppose we have a file 'my_file0.csv' with four CSV columns of
  different data types:
  ```
  abcdefg,4.28E10,5.55E6,12
  hijklmn,-5.3E14,,2
  ```

  We can construct a CsvDataset from it as follows:

  ```python
  tf.enable_eager_execution()

   dataset = tf.data.experimental.CsvDataset(
      "my_file*.csv",
      [tf.float32,  # Required field, use dtype or empty tensor
       tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
       tf.int32,  # Required field, use dtype or empty tensor
       ],
      select_cols=[1,2,3]  # Only parse last three columns
  )
  ```

  The expected output of its iterations is:

  ```python
  for element in dataset:
    print(element)

  >> (4.28e10, 5.55e6, 12)
  >> (-5.3e14, 0.0, 2)
  ```

  Args:
    filenames: A `tf.string` tensor containing one or more filenames.
    record_defaults: A list of default values for the CSV fields, one per
      column of CSV data. Each item is either a valid CSV `DType` (float32,
      float64, int32, int64, string) or a `Tensor` of one of those types.
      Use a scalar `Tensor` default value for an optional column, and a
      `DType` or empty `Tensor` for a required column. If both this and
      `select_cols` are specified, they must have the same length, and
      `record_defaults` is assumed to be sorted in order of increasing
      column index.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
      compression.
    buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
      to buffer while reading files. Defaults to 4MB.
    header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
      have header line(s) that should be skipped when parsing. Defaults to
      `False`.
    field_delim: (Optional.) A `tf.string` scalar containing the delimiter
      character that separates fields in a record. Defaults to `","`.
    use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
      double quotation marks as regular characters inside of string fields
      (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
    na_value: (Optional.) A `tf.string` scalar indicating a value that will
      be treated as NA/NaN.
    select_cols: (Optional.) A sorted list of column indices to select from
      the input data. If specified, only this subset of columns will be
      parsed. Defaults to parsing all columns.
  """
  self._filenames = ops.convert_to_tensor(
      filenames, dtype=dtypes.string, name="filenames")
  self._compression_type = convert.optional_param_to_tensor(
      "compression_type",
      compression_type,
      argument_default="",
      argument_dtype=dtypes.string)

  # A bare dtype marks a required column; it is replaced with an empty
  # constant of that dtype. Anything else is passed through unchanged.
  default_values = []
  for default in record_defaults:
    if default in _ACCEPTABLE_CSV_TYPES:
      default = constant_op.constant([], dtype=default)
    default_values.append(default)
  self._record_defaults = ops.convert_n_to_tensor(
      default_values, name="record_defaults")

  self._buffer_size = convert.optional_param_to_tensor(
      "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
  self._header = ops.convert_to_tensor(
      header, dtype=dtypes.bool, name="header")
  self._field_delim = ops.convert_to_tensor(
      field_delim, dtype=dtypes.string, name="field_delim")
  self._use_quote_delim = ops.convert_to_tensor(
      use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
  self._na_value = ops.convert_to_tensor(
      na_value, dtype=dtypes.string, name="na_value")
  self._select_cols = convert.optional_param_to_tensor(
      "select_cols",
      select_cols,
      argument_default=[],
      argument_dtype=dtypes.int64,
  )

  # One scalar-shaped component per record default, typed after that default.
  component_structures = tuple(
      structure.TensorStructure(default.dtype, [])
      for default in self._record_defaults)
  self._structure = structure.NestedStructure(component_structures)

  dataset_variant = gen_experimental_dataset_ops.experimental_csv_dataset(
      filenames=self._filenames,
      record_defaults=self._record_defaults,
      buffer_size=self._buffer_size,
      header=self._header,
      output_shapes=self._structure._flat_shapes,  # pylint: disable=protected-access
      field_delim=self._field_delim,
      use_quote_delim=self._use_quote_delim,
      na_value=self._na_value,
      select_cols=self._select_cols,
      compression_type=self._compression_type)
  super(CsvDatasetV2, self).__init__(dataset_variant)
def _make_reduce_func(self, reduce_func, input_dataset):
  """Make wrapping defun for reduce_func.

  The function is wrapped repeatedly, relaxing the state shapes each time,
  until the state shapes it returns no longer change the shapes it is given.
  The final wrapped function is stored in `self._reduce_func` and added to
  the default graph; `self._state_structure` holds the matching structure.

  Args:
    reduce_func: The function to wrap. It is traced with a nested structure
      of (state, input element).
    input_dataset: Dataset whose `_element_structure` describes the input
      element half of that structure.

  Raises:
    TypeError: If the classes or types returned by `reduce_func` do not match
      those returned by `self._init_func`.
  """

  # Iteratively rerun the reduce function until reaching a fixed point on
  # `self._state_structure`.
  self._state_structure = self._init_func.output_structure
  state_types = self._init_func.output_types
  state_shapes = self._init_func.output_shapes
  state_classes = self._init_func.output_classes
  need_to_rerun = True
  while need_to_rerun:

    # `add_to_graph=False` defers graph registration until the loop has
    # settled on the final wrapped function.
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_structure=structure.NestedStructure(
            (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
        add_to_graph=False)

    # Extract and validate class information from the returned values.
    for new_state_class, state_class in zip(
        nest.flatten(wrapped_func.output_classes),
        nest.flatten(state_classes)):
      if not issubclass(new_state_class, state_class):
        # Report the local `state_classes`: this method never assigns
        # `self._state_classes`, so formatting it here could raise
        # AttributeError in place of the intended TypeError.
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (state_classes, wrapped_func.output_classes))

    # Extract and validate type information from the returned values.
    for new_state_type, state_type in zip(
        nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
      if new_state_type != state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._init_func.output_types, wrapped_func.output_types))

    # Extract shape information from the returned values.
    flat_state_shapes = nest.flatten(state_shapes)
    flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Rerun only if some state shape with known rank was changed by the
    # weakening above.
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      state_shapes = nest.pack_sequence_as(
          self._init_func.output_shapes, weakened_state_shapes)
      self._state_structure = structure.convert_legacy_structure(
          state_types, state_shapes, state_classes)

  self._reduce_func = wrapped_func
  self._reduce_func.function.add_to_graph(ops.get_default_graph())
class OptionalTest(test.TestCase, parameterized.TestCase):
  """Tests for `optional_ops.Optional` and `optional_ops.OptionalStructure`."""

  @test_util.run_in_graph_and_eager_modes
  def testFromValue(self):
    """An optional built from a scalar tensor holds that scalar."""
    wrapped = optional_ops.Optional.from_value(constant_op.constant(37.0))
    self.assertTrue(self.evaluate(wrapped.has_value()))
    self.assertEqual(37.0, self.evaluate(wrapped.get_value()))

  @test_util.run_in_graph_and_eager_modes
  def testFromStructuredValue(self):
    """An optional built from a nested value returns the same nesting."""
    structured_value = {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    }
    wrapped = optional_ops.Optional.from_value(structured_value)
    self.assertTrue(self.evaluate(wrapped.has_value()))
    expected_value = {
        "a": 37.0,
        "b": ([b"Foo"], b"Bar")
    }
    self.assertEqual(expected_value, self.evaluate(wrapped.get_value()))

  @test_util.run_in_graph_and_eager_modes
  def testFromSparseTensor(self):
    """An optional built from sparse tensor values returns their components."""
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    st_1 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))
    wrapped = optional_ops.Optional.from_value((st_0, st_1))
    self.assertTrue(self.evaluate(wrapped.has_value()))
    val_0, val_1 = wrapped.get_value()
    for want, got in ((st_0, val_0), (st_1, val_1)):
      self.assertAllEqual(want.indices, self.evaluate(got.indices))
      self.assertAllEqual(want.values, self.evaluate(got.values))
      self.assertAllEqual(want.dense_shape, self.evaluate(got.dense_shape))

  @test_util.run_in_graph_and_eager_modes
  def testFromNone(self):
    """An empty optional keeps its structure and fails on `get_value()`."""
    value_structure = structure.TensorStructure(dtypes.float32, [])
    empty = optional_ops.Optional.none_from_structure(value_structure)
    self.assertTrue(empty.value_structure.is_compatible_with(value_structure))
    different_shape = structure.TensorStructure(dtypes.float32, [1])
    self.assertFalse(
        empty.value_structure.is_compatible_with(different_shape))
    different_dtype = structure.TensorStructure(dtypes.int32, [])
    self.assertFalse(
        empty.value_structure.is_compatible_with(different_dtype))
    self.assertFalse(self.evaluate(empty.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(empty.get_value())

  @test_util.run_in_graph_and_eager_modes
  def testCopyToGPU(self):
    """Optionals created on CPU remain usable after an identity op on GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          structure.TensorStructure(dtypes.float32, []))

    with ops.device("/gpu:0"):
      gpu_optional_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.value_structure)
      gpu_optional_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.value_structure)

      gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
      gpu_optional_with_value_values = gpu_optional_with_value.get_value()

      gpu_optional_none_has_value = gpu_optional_none.has_value()

    self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  def _assertElementValueEqual(self, expected, actual):
    """Recursively compares dicts, sparse tensor values and plain values."""
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for key in expected:
        self._assertElementValueEqual(expected[key], actual[key])
      return

    if isinstance(expected, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
      return

    self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testOptionalStructure(self, tf_value_fn, expected_value_structure):
    """Checks the structure derived from an optional and its round trip."""
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)

    actual_value_structure = opt.value_structure
    self.assertTrue(
        expected_value_structure.is_compatible_with(actual_value_structure))
    self.assertTrue(
        opt.value_structure.is_compatible_with(expected_value_structure))

    opt_structure = structure.Structure.from_value(opt)
    self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
    self.assertTrue(opt_structure.is_compatible_with(opt_structure))
    self.assertTrue(opt_structure._value_structure.is_compatible_with(
        expected_value_structure))
    self.assertEqual([dtypes.variant], opt_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

    # All OptionalStructure objects are not compatible with a non-optional
    # value.
    plain_tensor = constant_op.constant(42.0)
    non_optional_structure = structure.Structure.from_value(plain_tensor)
    self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    flattened_opt = opt_structure._to_tensor_list(opt)
    round_trip_opt = opt_structure._from_tensor_list(flattened_opt)
    if isinstance(tf_value, optional_ops.Optional):
      self.assertEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self.assertEqual(
          self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))

  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
       False),
      ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
                "b": sparse_tensor.SparseTensorValue(
                    indices=[[0, 0], [1, 1]],
                    values=np.array([-1., 1.], dtype=np.float32),
                    dense_shape=[2, 2])},
       lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
                "b": sparse_tensor.SparseTensor(
                    indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
                    dense_shape=[2, 2])}, False),
  )
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
    """Checks `get_next_as_optional` before, during and after iteration."""
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
    iterator = ds.make_initializable_iterator()
    next_elem = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(next_elem, optional_ops.Optional)
    expected_structure = structure.Structure.from_value(tf_value_fn())
    self.assertTrue(
        next_elem.value_structure.is_compatible_with(expected_structure))
    elem_has_value_t = next_elem.has_value()
    elem_value_t = next_elem.get_value()
    with self.cached_session() as sess:
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_value_t)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      sess.run(iterator.initializer)
      for _ in range(3):
        fetched = sess.run([elem_has_value_t, elem_value_t])
        elem_has_value, elem_value = fetched
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(sess.run(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(elem_value_t)
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `Dataset._as_serialized_graph()`, `_inputs()`, `options()`
  and `DatasetStructure`."""

  def testAsSerializedGraph(self):
    """The serialized graph of a range dataset contains a RangeDataset node."""
    dataset = dataset_ops.Dataset.range(10)
    with self.cached_session() as sess:
      graph = graph_pb2.GraphDef().FromString(
          sess.run(dataset._as_serialized_graph()))
      # The comparison must be `==`: with `!=` the assertion holds as soon as
      # the graph contains any other node, so it would never detect a missing
      # RangeDataset op.
      self.assertTrue(
          any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):
    """Returns a function that caches its argument through `Dataset.apply()`.

    Args:
      dataset: Unused; the returned function acts on the dataset passed to it.

    Returns:
      A callable mapping a dataset to `dataset.apply(<cache transformation>)`.
    """

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():
    """Returns a generator function that yields the single value 42."""

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):
    """Returns a function that interleaves empty range datasets.

    Args:
      dataset: Unused; the returned function acts on the dataset passed to it.
      num_parallel_calls: Forwarded unchanged to `Dataset.interleave()`.

    Returns:
      A callable mapping a dataset to its interleaved version
      (`cycle_length=2`).
    """

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetOpsTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      # This case previously built `from_tensors([42])`, duplicating the case
      # above and leaving `from_tensor_slices` untested.
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    """A source dataset reports the expected number of inputs.

    Args:
      dataset_fn: Nullary callable building the dataset under test.
      num_inputs: Expected length of `_inputs()`; defaults to 0.
    """
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    """A dataset built from sparse tensor slices has no inputs."""
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10), lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    """A unary transformation lists exactly its input dataset in `_inputs()`.

    Args:
      dataset_fn: Callable applying the transformation to a dataset.
      input_dataset_fn: Nullary callable building the input dataset.
    """
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    """A dataset produced via `apply()` lists the dataset it was applied to."""
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                    input_dataset_fn):
    """An interleaved dataset lists exactly its input dataset in `_inputs()`.

    Args:
      interleave_fn_args: Positional arguments for `make_interleave_fn()`.
      input_dataset_fn: Nullary callable building the input dataset.
    """
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    """A binary transformation lists both of its inputs, in argument order.

    Args:
      dataset_fn: Callable combining two datasets.
      input1_fn: Nullary callable building the first input dataset.
      input2_fn: Nullary callable building the second input dataset.
    """
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    """A variadic transformation lists its flattened inputs in `_inputs()`.

    Args:
      dataset_fn: Callable combining a (possibly nested) structure of datasets.
      input_datasets_fn: Nullary callable building that structure.
    """
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    """Walking `_inputs()` breadth-first visits every dataset in the graph."""
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue.pop(0)
      queue.extend(ds._inputs())
      inputs.append(ds)

    # ds3 zips (ds2, ds1, ds2) and each ds2 concatenates ds1 with itself, so
    # ds1 is reached once directly and twice through each of the two ds2.
    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  def testOptionsDefault(self):
    """A dataset without explicit options reports default `Options()`."""
    ds = dataset_ops.Dataset.range(0)
    self.assertEqual(dataset_ops.Options(), ds.options())

  def testOptionsOnce(self):
    """Options set once are still reported after a later transformation."""
    options = dataset_ops.Options()
    ds = dataset_ops.Dataset.range(0).with_options(options).cache()
    self.assertEqual(options, ds.options())

  def testOptionsTwiceSame(self):
    """Applying the same options twice yields those options."""
    options = dataset_ops.Options()
    options.experimental_autotune = True
    ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
        options)
    self.assertEqual(options, ds.options())

  def testOptionsTwiceDifferent(self):
    """Options that set different fields are merged."""
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_filter_fusion = False
    ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
        options2)
    self.assertTrue(ds.options().experimental_autotune)
    # Explicitly check that flag is False since assertFalse allows None
    self.assertIs(ds.options().experimental_filter_fusion, False)

  def testOptionsTwiceDifferentError(self):
    """Options that disagree on the same field cannot be merged."""
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_autotune = False
    with self.assertRaisesRegexp(ValueError,
                                 "Cannot merge incompatible values"):
      dataset_ops.Dataset.range(0).with_options(options1).with_options(
          options2)

  def testOptionsMergeOptionsFromMultipleInputs(self):
    """Zipping datasets merges the options of all of its inputs."""
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_filter_fusion = True
    ds = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(0).with_options(options1),
         dataset_ops.Dataset.range(0).with_options(options2)))
    self.assertTrue(ds.options().experimental_autotune)
    self.assertTrue(ds.options().experimental_filter_fusion)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.TensorStructure(dtypes.string, [1]),
                structure.TensorStructure(dtypes.string, []))
      })),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    """Checks the `DatasetStructure` of a dataset and its tensor round-trip.

    Args:
      tf_value_fn: Nullary callable returning the value each dataset element
        is mapped to.
      expected_element_structure: Structure the dataset's element structure
        must be mutually compatible with.
    """
    dataset = dataset_ops.Dataset.from_tensors(0).map(
        lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(
        expected_element_structure.is_compatible_with(
            dataset_structure._element_structure))
    self.assertTrue(
        dataset_structure._element_structure.is_compatible_with(
            expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      # NOTE(review): this branch compares against `dataset`, not
      # `round_trip_dataset`, so the round trip is not exercised for nested
      # datasets -- confirm whether that is intentional.
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)