示例#1
0
class StructureTest(test.TestCase, parameterized.TestCase):
    # pylint: disable=protected-access

    @parameterized.parameters(
        (constant_op.constant(37.0), structure.TensorStructure,
         [dtypes.float32], [[]]),
        (sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         structure.SparseTensorStructure, [dtypes.variant], [[3]]),
        ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
         structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
        ({
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]
                                                                       ]),
        ({
            "a":
            constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, structure.NestedStructure,
         [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
    def testFlatStructure(self, value, expected_structure, expected_types,
                          expected_shapes):
        """Checks the class, flat types and flat shapes inferred for `value`.

        Args:
          value: value handed to `structure.Structure.from_value`.
          expected_structure: class the inferred structure must be an instance
            of.
          expected_types: expected `_flat_types` of the inferred structure.
          expected_shapes: expected `_flat_shapes` of the inferred structure.
        """
        inferred = structure.Structure.from_value(value)
        self.assertIsInstance(inferred, expected_structure)
        self.assertEqual(expected_types, inferred._flat_types)
        self.assertEqual(expected_shapes, inferred._flat_shapes)

    @parameterized.parameters(
        (constant_op.constant(37.0), [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ], [constant_op.constant([1.0, 2.0]),
            constant_op.constant(37)]),
        (sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        ({
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    def testIsCompatibleWithStructure(self, original_value, compatible_values,
                                      incompatible_values):
        """Checks `is_compatible_with` against accepted and rejected values.

        Args:
          original_value: value whose structure is the reference.
          compatible_values: values whose structures the reference must accept.
          incompatible_values: values whose structures the reference must
            reject.
        """
        reference = structure.Structure.from_value(original_value)

        for candidate in compatible_values:
            candidate_structure = structure.Structure.from_value(candidate)
            self.assertTrue(reference.is_compatible_with(candidate_structure))

        for candidate in incompatible_values:
            candidate_structure = structure.Structure.from_value(candidate)
            self.assertFalse(reference.is_compatible_with(candidate_structure))

    # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
    # will be executed before the (eager- or graph-mode) test environment has been
    # set up.
    # pylint: disable=g-long-lambda
    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), ),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, ),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that flattening a value and rebuilding it reproduces it.

        The value is converted with `_to_tensor_list`, rebuilt with
        `_from_tensor_list`, and both the original and the rebuilt value are
        evaluated and compared component by component.

        Args:
          value_fn: zero-argument callable returning the value under test. It
            is invoked inside the test body rather than at decoration time.
        """
        value = value_fn()
        s = structure.Structure.from_value(value)
        before = self.evaluate(value)
        after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        # `zip` stops at the shorter sequence, so without this check a round
        # trip that dropped (or added) components would still pass the loop.
        self.assertEqual(len(flat_before), len(flat_after))
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                # Sparse values are compared field by field.
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def testIncompatibleStructure(self):
        """Checks that mismatched structure/value pairings raise.

        Builds a dense tensor, a sparse tensor and a dict of two dense tensors,
        each with its own structure and flattened tensor list, then asserts
        that every cross pairing of a structure with another kind's value (or
        flattened value) raises with the expected message.
        """
        # Define three mutually incompatible values/structures, and assert that:
        # 1. Using one structure to flatten a value with an incompatible structure
        #    fails.
        # 2. Using one structure to restructure a flattened value with an
        #    incompatible structure fails.
        value_tensor = constant_op.constant(42.0)
        s_tensor = structure.Structure.from_value(value_tensor)
        flat_tensor = s_tensor._to_tensor_list(value_tensor)

        value_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0, 0]],
                                                         values=[1],
                                                         dense_shape=[1, 1])
        s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
        flat_sparse_tensor = s_sparse_tensor._to_tensor_list(
            value_sparse_tensor)

        value_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_nest = structure.Structure.from_value(value_nest)
        flat_nest = s_nest._to_tensor_list(value_nest)

        # Part 1: `_to_tensor_list` called with a value of another kind.
        with self.assertRaisesRegexp(
                ValueError,
                r"SparseTensor.* is not convertible to a tensor with "
                r"dtype.*float32.* and shape \(\)"):
            s_tensor._to_tensor_list(value_sparse_tensor)
        with self.assertRaisesRegexp(
                ValueError,
                r"Value \{.*\} is not convertible to a tensor with "
                r"dtype.*float32.* and shape \(\)"):
            s_tensor._to_tensor_list(value_nest)

        with self.assertRaisesRegexp(TypeError,
                                     "Input must be a SparseTensor"):
            s_sparse_tensor._to_tensor_list(value_tensor)

        with self.assertRaisesRegexp(TypeError,
                                     "Input must be a SparseTensor"):
            s_sparse_tensor._to_tensor_list(value_nest)

        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.* not compatible with the nested structure "
                ".*TensorStructure.*TensorStructure"):
            s_nest._to_tensor_list(value_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.* not compatible with the nested structure "
                ".*TensorStructure.*TensorStructure"):
            s_nest._to_tensor_list(value_sparse_tensor)

        # Part 2: `_from_tensor_list` called with a flat list produced by a
        # different structure.
        with self.assertRaisesRegexp(
                ValueError,
                r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
            s_tensor._from_tensor_list(flat_sparse_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "TensorStructure corresponds to a single tf.Tensor."):
            s_tensor._from_tensor_list(flat_nest)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_sparse_tensor._from_tensor_list(flat_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_sparse_tensor._from_tensor_list(flat_nest)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 1."):
            s_nest._from_tensor_list(flat_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 1."):
            s_nest._from_tensor_list(flat_sparse_tensor)

    def testIncompatibleNestedStructure(self):
        """Checks that mismatched nested structure/value pairings raise.

        Builds three dict values (`value_0`, `value_1`, `value_2`) together
        with their structures and flattened tensor lists, then asserts that
        pairing a structure with another value, or with another value's flat
        list, raises `ValueError` with the expected message.
        """
        # Define three mutually incompatible nested values/structures, and assert
        # that:
        # 1. Using one structure to flatten a value with an incompatible structure
        #    fails.
        # 2. Using one structure to restructure a flattened value with an
        #    incompatible structure fails.

        value_0 = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_0 = structure.Structure.from_value(value_0)
        flat_s_0 = s_0._to_tensor_list(value_0)

        # `value_1` has compatible nested structure with `value_0`, but different
        # classes.
        value_1 = {
            "a":
            constant_op.constant(37.0),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 0]],
                                       values=[1],
                                       dense_shape=[1, 1])
        }
        s_1 = structure.Structure.from_value(value_1)
        flat_s_1 = s_1._to_tensor_list(value_1)

        # `value_2` has incompatible nested structure with `value_0` and `value_1`.
        value_2 = {
            "a":
            constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                             values=[1],
                                             dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(indices=[[3, 4]],
                                             values=[-1],
                                             dense_shape=[4, 5]))
        }
        s_2 = structure.Structure.from_value(value_2)
        flat_s_2 = s_2._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.* not compatible with the nested structure "
                ".*TensorStructure"):
            s_0._to_tensor_list(value_1)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*TensorStructure"):
            s_0._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.* not compatible with the nested structure "
                ".*SparseTensorStructure"):
            s_1._to_tensor_list(value_0)

        # NOTE(review): this block repeats the `s_0`/`value_2` assertion above
        # verbatim, while `s_1._to_tensor_list(value_2)` is the one pairing this
        # test never exercises. Possibly `s_1` was intended here -- confirm
        # against the implementation before changing it.
        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*TensorStructure"):
            s_0._to_tensor_list(value_2)

        # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
        # needs to account for "a" coming before or after "b". It might be worth
        # adding a deterministic repr for these error messages (among other
        # improvements).
        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.*Tensor.* not compatible with the nested structure "
                ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_0)

        with self.assertRaisesRegexp(
                ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
                "not compatible with the nested structure .*"
                "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_1)

        # From here on: rebuilding from a flat list produced by a different
        # structure.
        with self.assertRaisesRegexp(
                ValueError,
                r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
            s_0._from_tensor_list(flat_s_1)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_0._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_1._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_1._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_1)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", dtypes.int32, tensor_shape.matrix(
            2, 2), sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        ("Nest", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.scalar(),
            "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, []))
         })),
    )
    def testFromLegacyStructure(self, output_types, output_shapes,
                                output_classes, expected_structure):
        """Checks `_from_legacy_structure` against an expected structure.

        Args:
          output_types: types argument for `_from_legacy_structure`.
          output_shapes: shapes argument for `_from_legacy_structure`.
          output_classes: classes argument for `_from_legacy_structure`.
          expected_structure: structure the converted result must be mutually
            compatible with.
        """
        converted = structure.Structure._from_legacy_structure(
            output_types, output_shapes, output_classes)
        # Compatibility has to hold in both directions.
        self.assertTrue(expected_structure.is_compatible_with(converted))
        self.assertTrue(converted.is_compatible_with(expected_structure))
 def _element_structure(self):
     """Returns a nested tuple of `self._num_outputs` string structures.

     Every entry is the same `TensorStructure(dtypes.string, [])` object.
     """
     string_component = structure.TensorStructure(dtypes.string, [])
     return structure.NestedStructure(
         (string_component,) * self._num_outputs)
 def _element_structure(self):
     """Returns a nested pair of `TensorStructure(dtypes.string, [])`."""
     first = structure.TensorStructure(dtypes.string, [])
     second = structure.TensorStructure(dtypes.string, [])
     return structure.NestedStructure((first, second))
示例#4
0
class IteratorTest(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testNoGradients(self):
    """Checks that no gradient is reported through a one-shot iterator.

    The gradient of the iterator output is requested with respect to the
    dataset's source tensor, a tensor captured by the map function, and both
    together; the first entry of each result must be `None`.
    """
    component = constant_op.constant([1.])
    side = constant_op.constant(0.)
    dataset = dataset_ops.Dataset.from_tensor_slices(component).map(
        lambda x: x + side)
    value = dataset_ops.make_one_shot_iterator(dataset).get_next()
    for wrt in (component, side, [component, side]):
      self.assertIsNone(gradients_impl.gradients(value, wrt)[0])

  @test_util.run_deprecated_v1
  def testCapturingStateInOneShotRaisesException(self):
    """Checks that a one-shot iterator rejects a dataset capturing a variable."""
    captured_var = variables.Variable(37.0, name="myvar")
    source = dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
    dataset = source.map(lambda x: x + captured_var)
    expected_message = (
        r"`Dataset.make_one_shot_iterator\(\)` does not support "
        "datasets that capture stateful objects.+myvar")
    with self.assertRaisesRegexp(ValueError, expected_message):
      dataset_ops.make_one_shot_iterator(dataset)

  @test_util.run_deprecated_v1
  def testOneShotIterator(self):
    """Checks shapes and values produced by a one-shot iterator.

    Slices three NumPy components, squares them, repeats the dataset 14 times
    and verifies every element before expecting `OutOfRangeError`.
    """
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.map(_map_fn).repeat(14)
    get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

    # Each element drops the leading (sliced) dimension of its component.
    expected_shapes = [c.shape[1:] for c in components]
    self.assertEqual(expected_shapes, [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for row in range(7):
          actual = sess.run(get_next)
          for expected, got in zip(components, actual):
            self.assertAllEqual(expected[row]**2, got)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @test_util.run_deprecated_v1
  def testOneShotIteratorCaptureByValue(self):
    """Same check as `testOneShotIterator`, but slicing tensors.

    The NumPy components are first converted with `ops.convert_to_tensor`, and
    the dataset is built from the resulting tensors.
    """
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    tensor_components = tuple(ops.convert_to_tensor(c) for c in components)

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(tensor_components)
    dataset = dataset.map(_map_fn).repeat(14)
    get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

    # Each element drops the leading (sliced) dimension of its component.
    expected_shapes = [c.shape[1:] for c in components]
    self.assertEqual(expected_shapes, [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for row in range(7):
          actual = sess.run(get_next)
          for expected, got in zip(components, actual):
            self.assertAllEqual(expected[row]**2, got)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testOneShotIteratorInsideContainer(self):
    """Runs a one-shot iterator to exhaustion inside two named containers.

    Two sessions against the same local server each build the iterator under
    their own `ops.container` name and must both see the full sequence.
    """
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def within_container():

      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)

      squared = dataset_ops.Dataset.from_tensor_slices(components)
      squared = squared.map(_map_fn).repeat(14)
      return dataset_ops.make_one_shot_iterator(squared).get_next()

    server = server_lib.Server.create_local_server()

    # Create two iterators within unique containers, and run them to
    # make sure that the resources aren't shared.
    #
    # The test below would fail if the container name were the same across
    # both sessions.
    for session_index in range(2):
      with session.Session(server.target) as sess:
        with ops.container("iteration%d" % session_index):
          get_next = within_container()

        for _ in range(14):
          for row in range(7):
            actual = sess.run(get_next)
            for expected, got in zip(components, actual):
              self.assertAllEqual(expected[row]**2, got)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @test_util.run_deprecated_v1
  def testOneShotIteratorNonBlocking(self):
    """Exercises a one-shot iterator in sessions with one inter-op thread.

    First runs the iterator sequentially, then from eight concurrent threads
    in a second session, expecting exactly one thread to obtain the element
    and the others to see `OutOfRangeError`.
    """
    squared = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
    next_element = dataset_ops.make_one_shot_iterator(squared).get_next()

    # Create a session with a single thread to ensure that the
    # one-shot iterator initializer does not deadlock.
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, use_per_session_threads=True)
    with session.Session(config=config) as sess:
      self.assertAllEqual([1, 4, 9], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Test with multiple threads invoking the one-shot iterator concurrently.
    with session.Session(config=config) as sess:
      outcomes = []

      def consumer_thread():
        # Record `None` for a thread that found the iterator exhausted.
        try:
          outcomes.append(sess.run(next_element))
        except errors.OutOfRangeError:
          outcomes.append(None)

      num_threads = 8
      workers = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for worker in workers:
        worker.start()
      for worker in workers:
        worker.join()

      produced = [r for r in outcomes if r is not None]
      exhausted_count = len(outcomes) - len(produced)
      self.assertEqual(num_threads, len(outcomes))
      self.assertEqual(num_threads - 1, exhausted_count)
      self.assertAllEqual([[1, 4, 9]], produced)

  @test_util.run_deprecated_v1
  def testOneShotIteratorInitializerFails(self):
    """Checks that a failing one-shot initializer fails on every use.

    The dataset wraps `check_numerics` around `1.0 / 0.0`; running the iterator
    must raise `InvalidArgumentError` matching "oops" sequentially, repeatedly,
    and from eight concurrent threads.
    """
    # Define a dataset whose initialization will always fail.
    bad_value = constant_op.constant(1.0) / constant_op.constant(0.0)
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.check_numerics(bad_value, "oops"))
    next_element = dataset_ops.make_one_shot_iterator(dataset).get_next()

    def expect_failure(sess):
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(next_element)

    with self.cached_session() as sess:
      expect_failure(sess)

      # Test that subsequent attempts to use the iterator also fail.
      expect_failure(sess)

    with self.cached_session() as sess:

      def consumer_thread():
        expect_failure(sess)

      num_threads = 8
      workers = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for worker in workers:
        worker.start()
      for worker in workers:
        worker.join()

  def testSimpleSharedResource(self):
    """Checks that a named iterator is shared between sessions on one server.

    A first graph/session creates and initializes an iterator with
    `shared_name="shared_iterator"`, consumes it and re-initializes it. A
    second graph/session re-declares the iterator from its structure only and
    reads the element without running any initializer.
    """
    components = (np.array(1, dtype=np.int64),
                  np.array([1, 2, 3], dtype=np.int64),
                  np.array(37.0, dtype=np.float64))

    server = server_lib.Server.create_local_server()

    # Create two non-overlapping sessions that share the same iterator
    # resource on the same server, and verify that an action of the
    # first session (initializing the iterator) is visible in the
    # second session.
    with ops.Graph().as_default():
      # Use the module-level helper, as the other tests in this class do,
      # instead of the `Dataset.make_initializable_iterator()` method.
      dataset = (
          dataset_ops.Dataset.from_tensors(components)
          .map(lambda x, y, z: (x, y, z)))
      iterator = dataset_ops.make_initializable_iterator(
          dataset, shared_name="shared_iterator")
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Re-initialize the iterator in the first session.
        sess.run(init_op)

    with ops.Graph().as_default():
      # Re-define the iterator manually, without defining any of the
      # functions in this graph, to ensure that we are not
      # accidentally redefining functions with the same names in the
      # new graph.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # Use the iterator without re-initializing in the second session.
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @test_util.run_deprecated_v1
  def testNotInitializedError(self):
    """Checks the error raised when an iterator is used before initialization."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    dataset = dataset_ops.Dataset.from_tensors(components)
    get_next = dataset_ops.make_initializable_iterator(dataset).get_next()

    expected_message = "iterator has not been initialized"
    with self.cached_session() as sess:
      # The initializer is deliberately never run.
      with self.assertRaisesRegexp(errors.FailedPreconditionError,
                                   expected_message):
        sess.run(get_next)

  @test_util.run_deprecated_v1
  def testReinitializableIterator(self):
    """Checks that one iterator can be re-initialized from two datasets.

    The iterator is created from a structure with shape `[None]`, then
    initialized in turn from a 3-element dataset, a 4-element dataset and the
    3-element dataset again, being consumed to exhaustion each time.
    """
    dataset_3 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([1, 2, 3]))
    dataset_4 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([4, 5, 6, 7]))
    iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
                                                    [None])

    dataset_3_init_op = iterator.make_initializer(dataset_3)
    dataset_4_init_op = iterator.make_initializer(dataset_4)
    get_next = iterator.get_next()

    for dataset in (dataset_3, dataset_4):
      self.assertEqual(dataset.output_types, iterator.output_types)
    self.assertEqual([None], iterator.output_shapes.as_list())

    # Initialize with one dataset, then a different one, then the first again.
    init_sequence = (
        (dataset_3_init_op, [1, 2, 3]),
        (dataset_4_init_op, [4, 5, 6, 7]),
        (dataset_3_init_op, [1, 2, 3]),
    )

    with self.cached_session() as sess:
      # The iterator is initially uninitialized.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(get_next)

      for init_op, expected_element in init_sequence:
        sess.run(init_op)
        self.assertAllEqual(expected_element, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @test_util.run_deprecated_v1
  def testReinitializableIteratorWithFunctions(self):
    """Re-initializes one iterator from two generator-backed datasets.

    Each pass builds a fresh `Dataset.from_generator` dataset yielding 0..9,
    initializes the shared iterator from it and consumes it to exhaustion.
    """

    def g():
      for i in range(10):
        yield i

    iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      # Two passes, each with its own newly created dataset.
      for _ in range(2):
        dataset = dataset_ops.Dataset.from_generator(
            g, output_types=dtypes.int64)
        sess.run(iterator.make_initializer(dataset))
        for expected in range(10):
          self.assertEqual(expected, sess.run(next_element))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

  def testReinitializableIteratorStaticErrors(self):
    """Checks the errors raised for mismatched iterator/dataset definitions.

    Covers a types/shapes structure mismatch in `Iterator.from_structure`, and
    `make_initializer` calls with a dataset whose nesting, dtypes or shapes
    differ from the iterator's.
    """
    # Non-matching structure for types and shapes.
    with self.assertRaises(TypeError):
      iterator_ops.Iterator.from_structure(
          (dtypes.int64, dtypes.float64), [None])

    # Test validation of dataset argument.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64))

    # Incompatible structure.
    with self.assertRaises(ValueError):
      nested_ints = (constant_op.constant([1, 2, 3], dtype=dtypes.int64),)
      nested_floats = (
          constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64),)
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((nested_ints, nested_floats)))

    # Incompatible types.
    with self.assertRaises(TypeError):
      int32_values = constant_op.constant([1, 2, 3], dtype=dtypes.int32)
      float32_values = constant_op.constant(
          [4., 5., 6., 7.], dtype=dtypes.float32)
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((int32_values, float32_values)))

    # Incompatible shapes.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64), ([None], []))
    with self.assertRaises(TypeError):
      int64_values = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
      float64_values = constant_op.constant(
          [4., 5., 6., 7.], dtype=dtypes.float64)
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((int64_values, float64_values)))

  @test_util.run_deprecated_v1
  def testIteratorStringHandle(self):
    """Drives two one-shot iterators through one feedable iterator.

    The string handles of both iterators are fed alternately into a
    placeholder-backed iterator, and the elements must come back interleaved
    until each underlying iterator reports `OutOfRangeError`.
    """
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
    next_element = feedable_iterator.get_next()

    for dataset in (dataset_3, dataset_4):
      self.assertEqual(dataset.output_types, feedable_iterator.output_types)
    self.assertEqual([], feedable_iterator.output_shapes)

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      def pull(handle):
        return sess.run(next_element, feed_dict={handle_placeholder: handle})

      # Expected element for each feed, alternating between the two handles.
      expected_sequence = (
          (10, iterator_4_handle),
          (1, iterator_3_handle),
          (20, iterator_4_handle),
          (2, iterator_3_handle),
          (30, iterator_4_handle),
          (3, iterator_3_handle),
          (40, iterator_4_handle),
      )
      for expected_value, handle in expected_sequence:
        self.assertEqual(expected_value, pull(handle))

      # Both underlying iterators are now exhausted.
      for exhausted_handle in (iterator_3_handle, iterator_4_handle):
        with self.assertRaises(errors.OutOfRangeError):
          pull(exhausted_handle)

  @test_util.run_deprecated_v1
  def testIteratorStringHandleFuture(self):
    """Checks feedable string-handle iteration under a forward-compat horizon.

    Two one-shot iterators (three and four elements) are consumed through a
    single feedable iterator by alternating which string handle is fed. Each
    handle must yield its own elements in order and raise `OutOfRangeError`
    once they are used up.
    """
    with forward_compat.forward_compatibility_horizon(2018, 8, 4):
      short_dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      long_dataset = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

      short_iterator = dataset_ops.make_one_shot_iterator(short_dataset)
      long_iterator = dataset_ops.make_one_shot_iterator(long_dataset)

      handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      feedable_iterator = iterator_ops.Iterator.from_string_handle(
          handle_placeholder,
          short_dataset.output_types,
          short_dataset.output_shapes)
      next_element = feedable_iterator.get_next()

      # Both source datasets must agree with the feedable iterator's signature.
      for source in (short_dataset, long_dataset):
        self.assertEqual(source.output_types, feedable_iterator.output_types)
      self.assertEqual([], feedable_iterator.output_shapes)

      with self.cached_session() as sess:
        short_handle = sess.run(short_iterator.string_handle())
        long_handle = sess.run(long_iterator.string_handle())

        def _pull(handle):
          # Fetches the next element of whichever iterator `handle` names.
          return sess.run(
              next_element, feed_dict={handle_placeholder: handle})

        # (expected value, handle to feed), in the order the calls are made.
        interleaved = (
            (10, long_handle),
            (1, short_handle),
            (20, long_handle),
            (2, short_handle),
            (30, long_handle),
            (3, short_handle),
            (40, long_handle),
        )
        for expected, handle in interleaved:
          self.assertEqual(expected, _pull(handle))

        # Both underlying iterators are now exhausted.
        for exhausted_handle in (short_handle, long_handle):
          with self.assertRaises(errors.OutOfRangeError):
            _pull(exhausted_handle)

  @test_util.run_deprecated_v1
  def testIteratorStringHandleReuseTensorObject(self):
    """Checks that the default string handle is cached per iterator.

    Repeated `string_handle()` calls without a name must return the same
    tensor object and add no ops to the graph, whereas passing an explicit
    name must create a new, uniquely named op on every call.
    """
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
    initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
    structure_iterator = iterator_ops.Iterator.from_structure(
        dataset.output_types)

    def _num_graph_ops():
      return len(ops.get_default_graph().get_operations())

    ops_before = _num_graph_ops()

    for iterator in (one_shot_iterator, initializable_iterator,
                     structure_iterator):
      self.assertIs(iterator.string_handle(), iterator.string_handle())

    # Assert that getting the (default) string handle creates no ops.
    self.assertEqual(ops_before, _num_graph_ops())

    # Specifying an explicit name will create a new op.
    first_named = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo", first_named.op.name)
    self.assertIsNot(one_shot_iterator.string_handle(), first_named)

    # Asking for the same name again yields yet another op, uniquified.
    second_named = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo_1", second_named.op.name)
    self.assertIsNot(first_named, second_named)

  @test_util.run_deprecated_v1
  def testIteratorStringHandleError(self):
    """Checks type/shape validation when feeding a string handle.

    A handle to an int32 scalar iterator works with feedable iterators that
    declare a scalar shape or no shape at all, but feeding it (or a handle to
    a float vector iterator) to a feedable iterator declared as an int32
    vector must raise `InvalidArgumentError`.
    """
    int_scalar_dataset = (
        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
    float_vector_dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    scalar_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [])
    vector_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [None])
    any_shape_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32)

    with self.cached_session() as sess:
      int_scalar_source = dataset_ops.make_one_shot_iterator(
          int_scalar_dataset)
      handle_int_scalar = sess.run(int_scalar_source.string_handle())
      float_vector_source = dataset_ops.make_one_shot_iterator(
          float_vector_dataset)
      handle_float_vector = sess.run(float_vector_source.string_handle())

      feed_int_scalar = {handle_placeholder: handle_int_scalar}
      feed_float_vector = {handle_placeholder: handle_float_vector}

      # Compatible declarations: elements come out in order.
      self.assertEqual(
          1, sess.run(scalar_iterator.get_next(), feed_dict=feed_int_scalar))
      self.assertEqual(
          2,
          sess.run(any_shape_iterator.get_next(), feed_dict=feed_int_scalar))

      # Incompatible declarations: first a shape mismatch, then a dtype
      # mismatch.
      for mismatched_feed in (feed_int_scalar, feed_float_vector):
        with self.assertRaises(errors.InvalidArgumentError):
          print(sess.run(
              vector_iterator.get_next(), feed_dict=mismatched_feed))

  @test_util.run_deprecated_v1
  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    """Drives an iterator placed on cpu:1 through `remote_call`.

    The iterator's string handle is passed to a `Defun` that rebuilds the
    iterator and fetches its next element. Targeting cpu:1 yields 1, 2, 3 and
    then `OutOfRangeError`; targeting cpu:2 raises `InvalidArgumentError`.
    """
    cpu_1 = "/job:localhost/replica:0/task:0/cpu:1"
    cpu_2 = "/job:localhost/replica:0/task:0/cpu:2"

    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 3

    with ops.device(cpu_1):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.session(config=worker_config) as sess:

      def _call_on(device):
        # Runs the remote call with `device` fed as the call target.
        return sess.run(remote_op, feed_dict={target_placeholder: device})

      self.assertEqual(_call_on(cpu_1), [1])
      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        _call_on(cpu_2)
      # The remaining elements are fetched from cpu:1.
      self.assertEqual(_call_on(cpu_1), [2])
      self.assertEqual(_call_on(cpu_1), [3])
      with self.assertRaises(errors.OutOfRangeError):
        _call_on(cpu_1)

  @test_util.run_deprecated_v1
  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    """Fetches elements from iterators on two workers via `remote_call`.

    Three local servers are started and described in a cluster definition as
    a two-task "worker" job and a one-task "client" job. Each worker device
    hosts a one-shot iterator over a single-element dataset holding that
    device's own name. A client-side dataset zips the device names with the
    iterator string handles and maps them through `remote_call`; the test
    asserts that the produced values are the worker device names, in order,
    followed by `OutOfRangeError`.
    """
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()

    # Build the cluster definition; the "grpc://" scheme prefix is stripped
    # from each server target before it is stored as a task address.
    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = "client"
    client.tasks[0] = s3.target[len("grpc://"):]
    config = config_pb2.ConfigProto(cluster_def=cluster_def)

    worker_devices = [
        "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
    ]
    # One iterator per worker device; its only element is the device name.
    itr_handles = []
    for device in worker_devices:
      with ops.device(device):
        src = dataset_ops.Dataset.from_tensor_slices([device])
        itr = dataset_ops.make_one_shot_iterator(src)
        itr_handles.append(itr.string_handle())

    targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
    handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

    # NOTE(review): `itr` is read from the enclosing scope and is rebound
    # further down to the client iterator. Which binding this body sees
    # depends on when Defun traces it -- presumably during `.map(map_fn)`,
    # i.e. while `itr` is still the last worker iterator; confirm before
    # reordering or renaming anything here.
    @function.Defun(dtypes.string)
    def loading_func(h):
      remote_itr = iterator_ops.Iterator.from_string_handle(
          h, itr.output_types, itr.output_shapes)
      return remote_itr.get_next()

    def map_fn(target, handle):
      # Invokes `loading_func` with `handle` on the device named by `target`.
      return functional_ops.remote_call(
          args=[handle], Tout=[dtypes.string], f=loading_func, target=target)

    with ops.device("/job:client"):
      client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn)
      itr = dataset_ops.make_initializable_iterator(client_dataset)
      n = itr.get_next()

    with session.Session(s3.target, config=config) as sess:
      sess.run(itr.initializer)
      expected_values = worker_devices
      for expected in expected_values:
        # `remote_call` was built with a one-element `Tout` list, so each
        # fetched element is compared against a 1-tuple of bytes.
        self.assertEqual((compat.as_bytes(expected),), sess.run(n))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
    """Calls a CPU-resident iterator from an op placed on the GPU.

    The iterator's string handle is decoded to uint8 before being passed to
    `remote_call`, and converted back to a string inside the remote function
    with a `py_func`. The test is skipped when no GPU is available.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    cpu_0 = "/job:localhost/replica:0/task:0/cpu:0"

    with ops.device(cpu_0):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    def _encode_raw(byte_array):
      return bytes(bytearray(byte_array))

    @function.Defun(dtypes.uint8)
    def _remote_fn(h):
      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      iterator_3_handle_uint8 = parsing_ops.decode_raw(
          bytes=iterator_3_handle, out_type=dtypes.uint8)
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle_uint8],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.cached_session() as sess:
      cpu_feed = {target_placeholder: cpu_0}

      # Every call targets cpu:0, where the iterator resource was created.
      for expected in ([1], [2], [3]):
        elem = sess.run(remote_op, feed_dict=cpu_feed)
        self.assertEqual(elem, expected)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(remote_op, feed_dict=cpu_feed)

  @test_util.run_deprecated_v1
  def testIncorrectIteratorRestore(self):
    """Restoring saved iterator state into an incompatible iterator fails.

    The state of a `RangeDataset` iterator is serialized to a file in the
    test's temp directory; a second graph then tries to deserialize that file
    into a `FixedLengthRecordDataset` iterator and must raise
    `InvalidArgumentError`.
    """

    def _path():
      return os.path.join(self.get_temp_dir(), "iterator")

    def _save_op(iterator_resource):
      # Serializes the iterator state and writes it to `_path()`.
      serialized_state = parsing_ops.serialize_tensor(
          gen_dataset_ops.serialize_iterator(iterator_resource))
      return io_ops.write_file(_path(), serialized_state)

    def _restore_op(iterator_resource):
      # Reads `_path()` back and loads the parsed state into the iterator.
      saved_state = parsing_ops.parse_tensor(
          io_ops.read_file(_path()), dtypes.variant)
      return gen_dataset_ops.deserialize_iterator(iterator_resource,
                                                  saved_state)

    def _checkpointable_ops(dataset):
      # Returns (init_op, get_next, save_op, restore_op) for `dataset`.
      iterator = dataset_ops.make_initializable_iterator(dataset)
      resource = iterator._iterator_resource
      return (iterator.initializer, iterator.get_next(), _save_op(resource),
              _restore_op(resource))

    def _build_range_dataset_graph():
      return _checkpointable_ops(dataset_ops.Dataset.range(1, 10))

    def _build_reader_dataset_graph():
      filenames = ["test"]  # Does not exist but we don't care in this test.
      return _checkpointable_ops(
          readers.FixedLengthRecordDataset(filenames, 1, 0, 0))

    # Saving iterator for RangeDataset graph.
    with ops.Graph().as_default() as save_graph:
      init_op, _, save_op, _ = _build_range_dataset_graph()
      with self.session(graph=save_graph) as sess:
        sess.run(init_op)
        sess.run(save_op)

    # Attempt to restore the saved iterator into an IteratorResource of
    # incompatible type. An iterator of RangeDataset has output type int64,
    # while an iterator of FixedLengthRecordDataset has output type string.
    # So an InvalidArgumentError should be raised by
    # IteratorResource::set_iterator.
    with ops.Graph().as_default() as restore_graph:
      _, _, _, restore_op = _build_reader_dataset_graph()
      with self.session(graph=restore_graph) as sess:
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(restore_op)

  @test_util.run_deprecated_v1
  def testRepeatedGetNextWarning(self):
    """Checks that calling `get_next()` repeatedly emits warnings.

    Calls `get_next()` 100 times on one iterator and asserts that exactly
    `100 - GET_NEXT_CALL_WARNING_THRESHOLD` warnings were recorded, each
    containing `GET_NEXT_CALL_WARNING_MESSAGE`.
    """
    num_calls = 100
    iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
    with warnings.catch_warnings(record=True) as w:
      # Install the "always" filter inside the context manager so that it is
      # undone on exit; set outside, it would permanently alter the
      # process-wide warning filters and leak into other tests.
      warnings.simplefilter("always")
      for _ in range(num_calls):
        iterator.get_next()
    self.assertEqual(
        num_calls - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w))
    for warning in w:
      self.assertIn(
          iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message))

  def testEagerIteratorAsync(self):
    """Iterates a range dataset eagerly with async execution enabled.

    Every element produced must equal its position in the iteration.
    """
    with context.eager_mode(), context.execution_mode(context.ASYNC):
      dataset = dataset_ops.Dataset.range(10)
      for position, element in enumerate(dataset):
        self.assertEqual(position, element.numpy())

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, []),
       ops.Tensor,
       dtypes.float32,
       []),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[0]],
           values=constant_op.constant([0], dtype=dtypes.int32),
           dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1]),
       sparse_tensor.SparseTensor,
       dtypes.int32,
       [1]),
      ("Nest",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")),
       },
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, [])),
       }),
       {"a": ops.Tensor, "b": (ops.Tensor, ops.Tensor)},
       {"a": dtypes.float32, "b": (dtypes.string, dtypes.string)},
       {"a": [], "b": ([1], [])}),
  )
  def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                            expected_output_classes, expected_output_types,
                            expected_output_shapes):
    """Checks the structure metadata an iterator reports for its elements.

    Args:
      tf_value_fn: Zero-argument callable returning the value wrapped with
        `Dataset.from_tensors`.
      expected_element_structure: Structure that must be mutually compatible
        with the iterator's `_element_structure`.
      expected_output_classes: Expected `iterator.output_classes`.
      expected_output_types: Expected `iterator.output_types`.
      expected_output_shapes: Expected `iterator.output_shapes`.
    """
    element = tf_value_fn()
    dataset = dataset_ops.Dataset.from_tensors(element)
    iterator = dataset_ops.make_one_shot_iterator(dataset)

    # Compatibility has to hold in both directions.
    self.assertTrue(
        expected_element_structure.is_compatible_with(
            iterator._element_structure))
    self.assertTrue(
        iterator._element_structure.is_compatible_with(
            expected_element_structure))

    self.assertEqual(expected_output_classes, iterator.output_classes)
    self.assertEqual(expected_output_types, iterator.output_types)
    self.assertEqual(expected_output_shapes, iterator.output_shapes)

  def testIteratorGetNextName(self):
    """The `name` argument of `get_next()` names the resulting op."""
    with ops.Graph().as_default():
      dataset = dataset_ops.Dataset.from_tensors(37.0)
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      named_next = iterator.get_next(name="overridden_name")
      self.assertEqual("overridden_name", named_next.op.name)
示例#5
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for dataset graph serialization, input tracking and structure."""

    def testAsSerializedGraph(self):
        """The serialized graph of a range dataset has a RangeDataset node."""
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        # Check for the presence of the op. The previous `!=` comparison
        # passed for any graph holding at least one node of another type.
        self.assertTrue(
            any(node.op == "RangeDataset" for node in graph.node))

    @staticmethod
    def make_apply_fn(dataset):
        """Returns a function that caches its argument via `Dataset.apply`.

        Args:
          dataset: Unused; kept for interface compatibility.
        """
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        """Returns a generator function that yields the single value 42."""
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        """Returns a function that interleaves empty ranges over its argument.

        Args:
          dataset: Unused; kept for interface compatibility.
          num_parallel_calls: Forwarded to `Dataset.interleave`.
        """
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        # This case used to call `from_tensors`, duplicating the case above
        # and leaving `from_tensor_slices` untested.
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Source datasets report the expected number of input datasets.

        Args:
          dataset_fn: Zero-argument callable that builds the dataset.
          num_inputs: Expected length of the dataset's `_inputs()`.
        """
        self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

    def testDatasetComplexSourceInputs(self):
        """A dataset built from sparse tensor slices has no inputs."""
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEqual(0, len(dataset_fn._inputs()))

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation reports its source as its only input.

        Args:
          dataset_fn: Callable applying the transformation to a dataset.
          input_dataset_fn: Zero-argument callable building the input.
        """
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        """`Dataset.apply` reports the dataset it was applied to as input."""
        input_dataset = dataset_ops.Dataset.range(0)
        dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """Interleave reports the interleaved dataset as its only input.

        Args:
          interleave_fn_args: Positional arguments for `make_interleave_fn`.
          input_dataset_fn: Zero-argument callable building the input.
        """
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both operands, in order.

        Args:
          dataset_fn: Callable combining two datasets.
          input1_fn: Zero-argument callable building the first input.
          input2_fn: Zero-argument callable building the second input.
        """
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """A variadic transformation reports its flattened inputs.

        Args:
          dataset_fn: Callable taking a (possibly nested) dataset structure.
          input_datasets_fn: Zero-argument callable building that structure.
        """
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testCollectInputs(self):
        """A breadth-first walk over `_inputs()` visits shared inputs again."""
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        visited = []
        pending = [ds3]
        while pending:
            current = pending.pop(0)
            pending.extend(current._inputs())
            visited.append(current)

        # ds3 -> (ds2, ds1, ds2), and each ds2 -> (ds1, ds1).
        self.assertEqual(5, visited.count(ds1))
        self.assertEqual(2, visited.count(ds2))
        self.assertEqual(1, visited.count(ds3))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetStructure(
             structure.TensorStructure(dtypes.int32, []))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        """Checks the `DatasetStructure` derived from a dataset value.

        Args:
          tf_value_fn: Zero-argument callable returning the element value.
          expected_element_structure: Structure that must be mutually
            compatible with the dataset structure's element structure.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        self.assertTrue(
            expected_element_structure.is_compatible_with(
                dataset_structure._element_structure))
        self.assertTrue(
            dataset_structure._element_structure.is_compatible_with(
                expected_element_structure))

        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()],
                         dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset`, so the round trip is not exercised for
            # the "Dataset" case -- confirm whether that is intended.
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)

    # NOTE: This test is specific to graph mode and is skipped in eager mode.
    # NOTE(review): the body is identical to
    # `testSkipEagerSameGraphErrorInitializable` and creates no iterator --
    # confirm that the duplication is intended.
    @test_util.run_deprecated_v1
    def testSkipEagerSameGraphErrorOneShot(self):
        """Transforming a dataset from another graph raises `ValueError`."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    # NOTE: This test is specific to graph mode and is skipped in eager mode.
    @test_util.run_deprecated_v1
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        """A one-shot iterator on a foreign-graph dataset logs a warning."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with test.mock.patch.object(logging, "warning") as mock_log:
                _ = dataset.make_one_shot_iterator()
                self.assertRegexpMatches(
                    str(mock_log.call_args),
                    "Please ensure that all datasets in the "
                    "pipeline are created in the same graph as the iterator.")

    # NOTE: This test is specific to graph mode and is skipped in eager mode.
    @test_util.run_deprecated_v1
    def testSkipEagerSameGraphErrorInitializable(self):
        """Transforming a dataset from another graph raises `ValueError`."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)
示例#6
0
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `optional_ops.Optional` and `optional_ops.OptionalStructure`."""

    @staticmethod
    def _available_devices():
        """Returns the devices to test on: CPU, plus GPU when one is available."""
        devices = ["/cpu:0"]
        if test_util.is_gpu_available():
            devices.append("/gpu:0")
        return devices

    def testFromValue(self):
        """An optional built from a scalar tensor reports and returns it."""
        opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
        self.assertTrue(self.evaluate(opt.has_value()))
        self.assertEqual(37.0, self.evaluate(opt.get_value()))

    def testFromStructuredValue(self):
        """An optional built from a nested dict/tuple returns the same nest."""
        opt = optional_ops.Optional.from_value({
            "a":
            constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        })
        self.assertTrue(self.evaluate(opt.has_value()))
        self.assertEqual({
            "a": 37.0,
            "b": ([b"Foo"], b"Bar")
        }, self.evaluate(opt.get_value()))

    def testFromSparseTensor(self):
        """An optional built from sparse tensor values preserves each field."""
        st_0 = sparse_tensor.SparseTensorValue(indices=np.array([[0]]),
                                               values=np.array([0],
                                                               dtype=np.int64),
                                               dense_shape=np.array([1]))
        st_1 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1., 1.], dtype=np.float32),
            dense_shape=np.array([2, 2]))
        opt = optional_ops.Optional.from_value((st_0, st_1))
        self.assertTrue(self.evaluate(opt.has_value()))
        val_0, val_1 = opt.get_value()
        for expected, actual in [(st_0, val_0), (st_1, val_1)]:
            self.assertAllEqual(expected.indices,
                                self.evaluate(actual.indices))
            self.assertAllEqual(expected.values, self.evaluate(actual.values))
            self.assertAllEqual(expected.dense_shape,
                                self.evaluate(actual.dense_shape))

    def testFromNone(self):
        """An empty optional keeps its structure and fails on `get_value()`."""
        value_structure = structure.TensorStructure(dtypes.float32, [])
        opt = optional_ops.Optional.none_from_structure(value_structure)
        self.assertTrue(
            opt.value_structure.is_compatible_with(value_structure))
        # A different shape or a different dtype must not be compatible.
        self.assertFalse(
            opt.value_structure.is_compatible_with(
                structure.TensorStructure(dtypes.float32, [1])))
        self.assertFalse(
            opt.value_structure.is_compatible_with(
                structure.TensorStructure(dtypes.int32, [])))
        self.assertFalse(self.evaluate(opt.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(opt.get_value())

    def testAddN(self):
        """`add_n` sums optionals component-wise; empty ones stay empty."""
        for device in self._available_devices():
            with ops.device(device):
                # With value
                opt1 = optional_ops.Optional.from_value((1.0, 2.0))
                opt2 = optional_ops.Optional.from_value((3.0, 4.0))

                add_tensor = math_ops.add_n(
                    [opt1._variant_tensor, opt2._variant_tensor])
                add_opt = optional_ops._OptionalImpl(add_tensor,
                                                     opt1.value_structure)
                self.assertAllEqual(self.evaluate(add_opt.get_value()),
                                    (4.0, 6.0))

                # Without value
                opt_none1 = optional_ops.Optional.none_from_structure(
                    opt1.value_structure)
                opt_none2 = optional_ops.Optional.none_from_structure(
                    opt2.value_structure)
                add_tensor = math_ops.add_n(
                    [opt_none1._variant_tensor, opt_none2._variant_tensor])
                add_opt = optional_ops._OptionalImpl(add_tensor,
                                                     opt_none1.value_structure)
                self.assertFalse(self.evaluate(add_opt.has_value()))

    def testNestedAddN(self):
        """`add_n` also sums optionals nested inside other optionals."""
        for device in self._available_devices():
            with ops.device(device):
                opt1 = optional_ops.Optional.from_value([1, 2.0])
                opt2 = optional_ops.Optional.from_value([3, 4.0])
                # The outer optionals carry the inner ones as variant tensors.
                opt3 = optional_ops.Optional.from_value(
                    (5.0, opt1._variant_tensor))
                opt4 = optional_ops.Optional.from_value(
                    (6.0, opt2._variant_tensor))

                add_tensor = math_ops.add_n(
                    [opt3._variant_tensor, opt4._variant_tensor])
                add_opt = optional_ops._OptionalImpl(add_tensor,
                                                     opt3.value_structure)
                self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)

                inner_add_opt = optional_ops._OptionalImpl(
                    add_opt.get_value()[1], opt1.value_structure)
                self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])

    def testZerosLike(self):
        """`zeros_like` zeroes an optional's value; empty ones stay empty."""
        for device in self._available_devices():
            with ops.device(device):
                # With value
                opt = optional_ops.Optional.from_value((1.0, 2.0))
                zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
                zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                                       opt.value_structure)
                self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
                                    (0.0, 0.0))

                # Without value
                opt_none = optional_ops.Optional.none_from_structure(
                    opt.value_structure)
                zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
                zeros_opt = optional_ops._OptionalImpl(
                    zeros_tensor, opt_none.value_structure)
                self.assertFalse(self.evaluate(zeros_opt.has_value()))

    def testNestedZerosLike(self):
        """`zeros_like` also zeroes an optional nested in another optional."""
        for device in self._available_devices():
            with ops.device(device):
                opt1 = optional_ops.Optional.from_value(1.0)
                opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)

                zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
                zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                                       opt2.value_structure)
                inner_zeros_opt = optional_ops._OptionalImpl(
                    zeros_opt.get_value(), opt1.value_structure)
                self.assertEqual(self.evaluate(inner_zeros_opt.get_value()),
                                 0.0)

    def testCopyToGPU(self):
        """Optionals created on CPU can be read after `identity` on GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            optional_with_value = optional_ops.Optional.from_value(
                (constant_op.constant(37.0), constant_op.constant("Foo"),
                 constant_op.constant(42)))
            optional_none = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))

        with ops.device("/gpu:0"):
            gpu_optional_with_value = optional_ops._OptionalImpl(
                array_ops.identity(optional_with_value._variant_tensor),
                optional_with_value.value_structure)
            gpu_optional_none = optional_ops._OptionalImpl(
                array_ops.identity(optional_none._variant_tensor),
                optional_none.value_structure)

            gpu_optional_with_value_has_value = (
                gpu_optional_with_value.has_value())
            gpu_optional_with_value_values = (
                gpu_optional_with_value.get_value())

            gpu_optional_none_has_value = gpu_optional_none.has_value()

        self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
        self.assertEqual((37.0, b"Foo", 42),
                         self.evaluate(gpu_optional_with_value_values))
        self.assertFalse(self.evaluate(gpu_optional_none_has_value))

    def testNestedCopyToGPU(self):
        """An optional holding other optionals can be read back on GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            optional_with_value = optional_ops.Optional.from_value(
                (constant_op.constant(37.0), constant_op.constant("Foo"),
                 constant_op.constant(42)))
            optional_none = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))
            nested_optional = optional_ops.Optional.from_value(
                (optional_with_value._variant_tensor,
                 optional_none._variant_tensor, 1.0))

        with ops.device("/gpu:0"):
            gpu_nested_optional = optional_ops._OptionalImpl(
                array_ops.identity(nested_optional._variant_tensor),
                nested_optional.value_structure)

            gpu_nested_optional_has_value = gpu_nested_optional.has_value()
            gpu_nested_optional_values = gpu_nested_optional.get_value()

        self.assertTrue(self.evaluate(gpu_nested_optional_has_value))

        # Re-wrap the inner variant tensors so their contents can be checked.
        inner_with_value = optional_ops._OptionalImpl(
            gpu_nested_optional_values[0], optional_with_value.value_structure)

        inner_none = optional_ops._OptionalImpl(gpu_nested_optional_values[1],
                                                optional_none.value_structure)

        self.assertEqual((37.0, b"Foo", 42),
                         self.evaluate(inner_with_value.get_value()))
        self.assertFalse(self.evaluate(inner_none.has_value()))
        self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))

    def _assertElementValueEqual(self, expected, actual):
        """Asserts `actual` equals `expected`, recursing into dicts.

        Sparse tensor values are compared field by field; every other value is
        compared with `assertAllEqual`.
        """
        if isinstance(expected, dict):
            self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
            for k in expected.keys():
                self._assertElementValueEqual(expected[k], actual[k])
        elif isinstance(expected, sparse_tensor.SparseTensorValue):
            self.assertAllEqual(expected.indices, actual.indices)
            self.assertAllEqual(expected.values, actual.values)
            self.assertAllEqual(expected.dense_shape, actual.dense_shape)
        else:
            self.assertAllEqual(expected, actual)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testOptionalStructure(self, tf_value_fn, expected_value_structure):
        """Checks the structure of an optional and its tensor-list round trip.

        Args:
          tf_value_fn: Zero-argument callable returning the value to wrap.
          expected_value_structure: Structure the wrapped value should have.
        """
        tf_value = tf_value_fn()
        opt = optional_ops.Optional.from_value(tf_value)

        self.assertTrue(
            expected_value_structure.is_compatible_with(opt.value_structure))
        self.assertTrue(
            opt.value_structure.is_compatible_with(expected_value_structure))

        opt_structure = structure.Structure.from_value(opt)
        self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
        self.assertTrue(opt_structure.is_compatible_with(opt_structure))
        self.assertTrue(
            opt_structure._value_structure.is_compatible_with(
                expected_value_structure))
        self.assertEqual([dtypes.variant], opt_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

        # All OptionalStructure objects are not compatible with a non-optional
        # value.
        non_optional_structure = structure.Structure.from_value(
            constant_op.constant(42.0))
        self.assertFalse(
            opt_structure.is_compatible_with(non_optional_structure))

        # Assert that the optional survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_opt = opt_structure._from_tensor_list(
            opt_structure._to_tensor_list(opt))
        if isinstance(tf_value, optional_ops.Optional):
            self.assertEqual(
                self.evaluate(tf_value.get_value()),
                self.evaluate(round_trip_opt.get_value().get_value()))
        else:
            self.assertEqual(self.evaluate(tf_value),
                             self.evaluate(round_trip_opt.get_value()))

    # NOTE: This test is specific to graph mode and is skipped in eager mode.
    @parameterized.named_parameters(
        ("Tensor", np.array([1, 2, 3], dtype=np.int32),
         lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
        ("SparseTensor",
         sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                         values=np.array([-1., 1.],
                                                         dtype=np.float32),
                                         dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                            values=[37.0, 42.0],
                                            dense_shape=[2, 2]), False),
        ("Nest", {
            "a":
            np.array([1, 2, 3], dtype=np.int32),
            "b":
            sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                            values=np.array([-1., 1.],
                                                            dtype=np.float32),
                                            dense_shape=[2, 2])
        }, lambda: {
            "a":
            constant_op.constant([4, 5, 6], dtype=dtypes.int32),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                       values=[37.0, 42.0],
                                       dense_shape=[2, 2])
        }, False),
    )
    @test_util.run_deprecated_v1
    def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                               works_on_gpu):
        """Checks `get_next_as_optional` before, during and after iteration.

        Args:
          np_value: Element used to build a dataset that repeats it 3 times.
          tf_value_fn: Zero-argument callable returning a value whose structure
            the optional's value structure must be compatible with.
          works_on_gpu: Whether the case may run when a GPU is available.
        """
        if not works_on_gpu and test_util.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
        iterator = ds.make_initializable_iterator()
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertIsInstance(next_elem, optional_ops.Optional)
        self.assertTrue(
            next_elem.value_structure.is_compatible_with(
                structure.Structure.from_value(tf_value_fn())))
        elem_has_value_t = next_elem.has_value()
        elem_value_t = next_elem.get_value()
        with self.cached_session() as sess:
            # Before initializing the iterator, evaluating the optional fails with
            # a FailedPreconditionError.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(elem_has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(elem_value_t)

            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            sess.run(iterator.initializer)
            for _ in range(3):
                elem_has_value, elem_value = sess.run(
                    [elem_has_value_t, elem_value_t])
                self.assertTrue(elem_has_value)
                self._assertElementValueEqual(np_value, elem_value)

            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                self.assertFalse(sess.run(elem_has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    sess.run(elem_value_t)

    def testFunctionBoundaries(self):
        """An optional's variant tensor can be passed between functions."""
        @def_function.function
        def get_optional():
            x = constant_op.constant(1.0)
            opt = optional_ops.Optional.from_value(x)
            # TODO(skyewm): support returning Optionals from functions?
            return opt._variant_tensor

        # TODO(skyewm): support Optional arguments?
        @def_function.function
        def consume_optional(opt_tensor):
            value_structure = structure.TensorStructure(dtypes.float32, [])
            opt = optional_ops._OptionalImpl(opt_tensor, value_structure)
            return opt.get_value()

        opt_tensor = get_optional()
        val = consume_optional(opt_tensor)
        self.assertEqual(self.evaluate(val), 1.0)
示例#7
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `dataset_ops.Dataset` inputs, structure and iteration."""

    def testAsSerializedGraph(self):
        """The serialized graph of a range dataset has a `RangeDataset` node."""
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        # Check for the presence of the op. A `!=` comparison would pass for
        # any graph that merely contains some node of a different type.
        self.assertTrue(
            any(node.op == "RangeDataset" for node in graph.node))

    def testAsFunctionWithMap(self):
        """A mapped dataset revived from its traced variant yields the same."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            original_dataset = dataset_ops.Dataset.range(5).map(
                lambda x: x * 2)
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = _RevivedDataset(
                variant, original_dataset._element_structure)
            self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

    def testAsFunctionWithMapInFlatMap(self):
        """Same as `testAsFunctionWithMap`, with the map nested in flat_map."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            original_dataset = dataset_ops.Dataset.range(5).flat_map(
                lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = _RevivedDataset(
                variant, original_dataset._element_structure)
            self.assertDatasetProduces(revived_dataset, list(original_dataset))

    @staticmethod
    def make_apply_fn(dataset):
        """Returns a function that caches its input via `Dataset.apply()`.

        The `dataset` argument is not used by the returned function.
        """
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        """Returns a generator function that yields the single value 42."""
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        """Returns a function that interleaves empty ranges over its input.

        The `dataset` argument is not used by the returned function;
        `num_parallel_calls` is forwarded to `Dataset.interleave()`.
        """
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        # Use `from_tensor_slices` so this case covers the source its name
        # advertises instead of repeating the "FromTensors" case.
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Source datasets report the expected number of input datasets.

        Args:
          dataset_fn: Zero-argument callable that builds the dataset.
          num_inputs: Expected length of the dataset's `_inputs()`.
        """
        self.assertLen(dataset_fn()._inputs(), num_inputs)

    @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
    def testDatasetComplexSourceInputs(self):
        """A dataset built from sparse tensor slices has no inputs."""
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEmpty(dataset_fn._inputs())

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation reports exactly its input dataset.

        Args:
          dataset_fn: Callable applying the transformation to a dataset.
          input_dataset_fn: Zero-argument callable building the input dataset.
        """
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        """`Dataset.apply()` reports the dataset it was applied to as input."""
        input_dataset = dataset_ops.Dataset.range(0)
        dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """Interleave (parallel or not) reports exactly its input dataset.

        Args:
          interleave_fn_args: Positional arguments for `make_interleave_fn`.
          input_dataset_fn: Zero-argument callable building the input dataset.
        """
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testNoWarnings(self):
        """Building an interleave pipeline does not call `warnings.warn`."""
        with test.mock.patch.object(warnings, "warn") as mock_log:
            dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
            dataset_fn(dataset_ops.Dataset.range(10))
            self.assertEmpty(mock_log.call_args_list)

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both inputs, in order.

        Args:
          dataset_fn: Callable combining two datasets.
          input1_fn: Zero-argument callable building the first input.
          input2_fn: Zero-argument callable building the second input.
        """
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """A variadic transformation reports its flattened input datasets.

        Args:
          dataset_fn: Callable combining a nest of datasets.
          input_datasets_fn: Zero-argument callable building that nest.
        """
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testFunctions(self):
        """A dataset with a single `map` reports exactly one function."""
        dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        self.assertLen(dataset._functions(), 1)

    def testCollectInputs(self):
        """A breadth-first walk over `_inputs()` visits shared inputs repeatedly."""
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        inputs = []
        queue = [ds3]
        while queue:
            ds = queue.pop(0)
            queue.extend(ds._inputs())
            inputs.append(ds)

        # ds1 is reached once directly from ds3 and twice through each ds2.
        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetStructure(
             structure.TensorStructure(dtypes.int32, []))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        """Checks a dataset's structure and its tensor-list round trip.

        Args:
          tf_value_fn: Zero-argument callable returning the element value.
          expected_element_structure: Structure the elements should have.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        self.assertTrue(
            expected_element_structure.is_compatible_with(
                dataset_structure._element_structure))
        self.assertTrue(
            dataset_structure._element_structure.is_compatible_with(
                expected_element_structure))

        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()],
                         dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset`, so the round trip is not exercised for
            # nested datasets - confirm whether that is intended.
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShot(self):
        """Batching a dataset under a different default graph raises."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        """A one-shot iterator made in another graph logs a same-graph hint."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with test.mock.patch.object(logging, "warning") as mock_log:
                _ = dataset_ops.make_one_shot_iterator(dataset)
                self.assertRegexpMatches(
                    str(mock_log.call_args),
                    "Please ensure that all datasets in the "
                    "pipeline are created in the same graph as the iterator.")

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorInitializable(self):
        """Batching a dataset under a different default graph raises.

        NOTE(review): the body is identical to
        `testSkipEagerSameGraphErrorOneShot` and creates no initializable
        iterator - confirm whether a distinct scenario was intended.
        """
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @parameterized.named_parameters(
        ("Async", context.ASYNC),
        ("Sync", context.SYNC),
    )
    def testDatasetEagerIteration(self, execution_mode):
        """Iterating a range dataset eagerly yields 0, 1, 2, ... in order.

        Args:
          execution_mode: Eager execution mode to run the iteration under.
        """
        with context.eager_mode(), context.execution_mode(execution_mode):
            val = 0
            dataset = dataset_ops.Dataset.range(10)
            for foo in dataset:
                self.assertEqual(val, foo.numpy())
                val += 1

    def testDatasetAsFunctionArgument(self):
        """Datasets can be passed to a `def_function.function` as arguments."""
        @def_function.function
        def _uses_dataset(d):
            accumulator = array_ops.zeros([], dtype=dtypes.int64)
            for value in d:
                accumulator += value
            return accumulator

        with ops.device("CPU"):
            first_dataset = dataset_ops.Dataset.range(10)
            self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
            second_dataset = dataset_ops.Dataset.range(11)
            self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
            first_concrete = _uses_dataset.get_concrete_function(first_dataset)
            # The dataset should not be a captured input
            self.assertEmpty(first_concrete.graph.captures)
            # The two datasets have the same structure and so should re-use a trace.
            self.assertIs(first_concrete,
                          _uses_dataset.get_concrete_function(second_dataset))
            # With a different structure we should use a different trace.
            self.assertIsNot(
                first_concrete,
                _uses_dataset.get_concrete_function(
                    dataset_ops.Dataset.zip((first_dataset, second_dataset))))

    def testLimitedRetracingWithCompositeTensors(self):
        """Calling a function with same-structure datasets traces it once."""
        # A one-element list lets the traced function mutate the counter.
        trace_count = [0]

        @def_function.function
        def f(ds):
            trace_count[0] += 1
            counter = np.int64(0)
            for elem in ds:
                counter += elem
            return counter

        dataset = dataset_ops.Dataset.range(5)
        dataset2 = dataset_ops.Dataset.range(10)

        for _ in range(10):
            self.assertEqual(self.evaluate(f(dataset)), 10)
            self.assertEqual(self.evaluate(f(dataset2)), 45)
            self.assertEqual(trace_count[0], 1)
示例#8
0
    def __init__(self,
                 input_dataset,
                 functions,
                 ratio_numerator=1,
                 ratio_denominator=1,
                 num_elements_per_branch=None):
        """Chooses the fastest of some dataset functions.

        Given dataset functions that take input_dataset as input and output
        another dataset, produces elements as quickly as the fastest of these
        output datasets. Note that datasets in the dataset functions are
        assumed to be stateless, and the iterators created by the functions'
        output datasets will, given the same input elements, all produce the
        same output elements. Datasets in the functions are also expected to
        iterate over the input dataset at most once. The violation of these
        conditions may lead to undefined behavior.

        For example:
        ```python
        dataset = tf.data.Dataset.range(100)
        dataset = _ChooseFastestBranchDataset(
            dataset,
            [
                lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
                lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
            ],
            ratio_numerator=10,
            num_elements_per_branch=10
        )
        ```
        The resulting dataset will produce elements equivalent to
        `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`,
        or
        `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

        Note that the first `num_elements_per_branch` iterations may be slower
        due to the overhead of dynamically picking the fastest dataset. Namely,
        for these iterations, the dataset will produce elements from any of the
        branches to determine which input is the fastest. For all subsequent
        iterations, that input will be used.

        Args:
          input_dataset: A `Dataset` that can be used as input to `functions`.
          functions: A list of callables, each of which takes a `Dataset` as
            input and returns a `Dataset`.
          ratio_numerator: The numerator in the ratio of input elements
            consumed to output elements produced for each function. This
            should be the same for all functions. For example, if the function
            is `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input
            dataset must produce 10 elements for every element of the output
            dataset. In this case, ratio_numerator should be 10.
          ratio_denominator: The denominator in the ratio of input elements
            consumed to output elements produced for each function. This
            should be the same for all functions. For example, if the function
            is `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input
            dataset must produce 10 elements for every element of the output
            dataset. In this case, ratio_denominator should be 1.
          num_elements_per_branch: The number of elements to get from each
            branch before deciding which dataset is fastest. In the first
            len(functions) * num_elements_per_branch iterations, the dataset
            will call from one of the branches, and update its knowledge of
            which input is the fastest. Note that
            (num_elements_per_branch * ratio) is expected to be an integer.
            Defaults to `10 * ratio_denominator` when `None`.

        Returns:
          A `Dataset` that has the same elements as the inputs.

        Raises:
          ValueError: If `ratio_numerator` or `ratio_denominator` is not
            positive.
        """
        # Validate the scalar arguments first so that an invalid ratio is
        # rejected before any of `functions` is wrapped.
        if ratio_numerator <= 0 or ratio_denominator <= 0:
            raise ValueError("ratio must be positive.")

        if num_elements_per_branch is None:
            # Pick a sensible default based on `ratio_denominator`
            num_elements_per_branch = 10 * ratio_denominator

        # Every branch function receives the input dataset itself (not its
        # elements), so the input structure describes a dataset.
        nested_structure = structure_lib.NestedStructure(
            dataset_ops.DatasetStructure(
                structure_lib.convert_legacy_structure(
                    input_dataset.output_types, input_dataset.output_shapes,
                    input_dataset.output_classes)))
        self._funcs = [
            dataset_ops.StructuredFunctionWrapper(
                f, "ChooseFastestV2", input_structure=nested_structure)
            for f in functions
        ]
        # The element structure of this dataset is taken from the first branch.
        self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

        # Captured inputs of all branches are concatenated into one flat list;
        # `_capture_lengths` records how many entries belong to each branch.
        self._captured_arguments = []
        for f in self._funcs:
            self._captured_arguments.extend(f.function.captured_inputs)
        self._capture_lengths = [
            len(f.function.captured_inputs) for f in self._funcs
        ]

        variant_tensor = (
            gen_experimental_dataset_ops.choose_fastest_branch_dataset(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                ratio_numerator=ratio_numerator,
                ratio_denominator=ratio_denominator,
                other_arguments=self._captured_arguments,
                num_elements_per_branch=num_elements_per_branch,
                branches=[f.function for f in self._funcs],
                other_arguments_lengths=self._capture_lengths,
                **dataset_ops.flat_structure(self)))
        super(_ChooseFastestBranchDataset,
              self).__init__(input_dataset, variant_tensor)
示例#9
0
  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details.

    Wraps `scan_func`, re-wrapping it with weakened state shapes until the
    state shapes stop changing, and then builds the scan dataset op.

    Args:
      input_dataset: The dataset whose element structure is passed, together
        with the state structure, as the input structure of `scan_func`.
      initial_state: The initial state; it is normalized with
        `structure.normalize_tensors` before use.
      scan_func: A callable that must return a pair comprising the new state
        and the output value.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the classes or
        types of the new state do not match those of the initial state.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_tensors(initial_state)

    # Compute the initial state structure (classes, shapes and types) based on
    # the initial state. The shapes may be weakened by re-wrapping `scan_func`
    # one or more times below.
    self._state_structure = type_spec.type_spec_from_value(self._initial_state)

    # Iteratively rerun the scan function until reaching a fixed point on
    # the state shapes held in `self._state_structure`.
    need_to_rerun = True
    while need_to_rerun:

      # `add_to_graph=False`: the function is only added to the default graph
      # once, after the loop has settled on a final wrapped function.
      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=structure.NestedStructure(
              (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
          add_to_graph=False)
      # NOTE(review): `collections.Sequence` is the legacy alias of
      # `collections.abc.Sequence` and no longer exists in Python 3.10+ --
      # confirm which Python versions this file must support.
      if not (
          isinstance(wrapped_func.output_types, collections.Sequence) and
          len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # NOTE(review): `new_state_classes` is reassigned immediately below, so
      # the only lasting effect of this line is setting `self._output_classes`
      # -- confirm whether that attribute is still read anywhere.
      new_state_classes, self._output_classes = wrapped_func.output_classes

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      old_state_classes = self._state_structure._to_legacy_output_classes()  # pylint: disable=protected-access
      for new_state_class, old_state_class in zip(
          nest.flatten(new_state_classes),
          nest.flatten(old_state_classes)):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = self._state_structure._to_legacy_output_types()  # pylint: disable=protected-access
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = self._state_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
      # The output structure of this dataset is rebuilt on every iteration, so
      # the value from the final iteration is the one that is kept.
      self._structure = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      # Combine each old state shape with the shape `scan_func` returned for
      # it, using `most_specific_compatible_shape`.
      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Another pass is needed as soon as any state shape of known rank differs
      # from its weakened counterpart.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    # Keep the wrapped function from the last iteration and register it with
    # the default graph now that the state structure is final.
    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())
    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
        self._input_dataset._variant_tensor,
        self._state_structure._to_tensor_list(self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **dataset_ops.flat_structure(self))
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
示例#10
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    ragged_test_util.RaggedTensorTestCase):

  # pylint: disable=g-long-lambda,protected-access
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0), structure.TensorStructure,
       [dtypes.float32], [[]]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       structure.SparseTensorStructure, [dtypes.variant], [None]),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
       structure.RaggedTensorStructure, [dtypes.variant], [None]),
      ("Nested_0",
       lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
      ("Nested_1", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
      ("Nested_2", lambda: {
          "a":
              constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, structure.NestedStructure,
       [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]),
  )
  def testFlatStructure(self, value_fn, expected_structure, expected_types,
                        expected_shapes):
    """Checks the structure class, flat types and flat shapes of a value."""
    actual_structure = structure.Structure.from_value(value_fn())
    self.assertIsInstance(actual_structure, expected_structure)
    self.assertEqual(expected_types, actual_structure._flat_types)
    shape_pairs = zip(expected_shapes, actual_structure._flat_shapes)
    for expected_shape, actual_shape in shape_pairs:
      # Compatibility is checked in both directions.
      self.assertTrue(actual_shape.is_compatible_with(expected_shape))
      expected_as_shape = tensor_shape.as_shape(expected_shape)
      self.assertTrue(expected_as_shape.is_compatible_with(actual_shape))

  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0), lambda: [
          constant_op.constant(38.0),
          array_ops.placeholder(dtypes.float32),
          variables.Variable(100.0), 42.0,
          np.array(42.0, dtype=np.float32)
      ], lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0), lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=10)
          ], lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.int32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(), size=0)
          ]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: [
           sparse_tensor.SparseTensor(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           sparse_tensor.SparseTensorValue(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           array_ops.sparse_placeholder(dtype=dtypes.int32),
           array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
       ], lambda: [
           constant_op.constant(37, shape=[4, 5]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None, None]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
       ]),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
       lambda: [
           ragged_factory_ops.constant([[1, 2], [3, 4], []]),
           ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
       ], lambda: [
           ragged_factory_ops.constant(1),
           ragged_factory_ops.constant([1, 2]),
           ragged_factory_ops.constant([[1], [2]]),
           ragged_factory_ops.constant([["a", "b"]]),
       ]),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6])
      }], lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6, 7])
      }, {
          "a": constant_op.constant(15),
          "b": constant_op.constant([4, 5, 6])
      }, {
          "a":
              constant_op.constant(15),
          "b":
              sparse_tensor.SparseTensor(
                  indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
      }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  @test_util.run_deprecated_v1
  def testIsCompatibleWithStructure(
      self, original_value_fn, compatible_values_fn, incompatible_values_fn):
    """Checks `is_compatible_with` for matching and mismatching values."""
    # All values are built before any structure is derived from them.
    original = original_value_fn()
    should_match = compatible_values_fn()
    should_not_match = incompatible_values_fn()
    reference = structure.Structure.from_value(original)
    for candidate in should_match:
      candidate_structure = structure.Structure.from_value(candidate)
      self.assertTrue(reference.is_compatible_with(candidate_structure))
    for candidate in should_not_match:
      candidate_structure = structure.Structure.from_value(candidate)
      self.assertFalse(reference.is_compatible_with(candidate_structure))

  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0),
       lambda: constant_op.constant([5])),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
       lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]],
                                           ragged_rank=1),
       lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
      ("Nested",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])},
       lambda: {
           "a": constant_op.constant(42.0),
           "b": constant_op.constant([4, 5, 6])},
       lambda: {
           "a": constant_op.constant([1, 2, 3]),
           "b": constant_op.constant(37.0)
       }),
  )  # pyformat: disable
  def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                     *not_equal_value_fns):
    """Checks `==`, `!=` and `hash` on structures derived from values."""
    # pylint: disable=g-generic-assert
    first = structure.Structure.from_value(value1_fn())
    second = structure.Structure.from_value(value2_fn())
    equal_structures = (first, second)
    for other in equal_structures:
      self.assertEqual(first, other)  # check __eq__ operator.
    for other in equal_structures:
      self.assertFalse(first != other)  # check __ne__ operator.
    for other in equal_structures:
      self.assertEqual(hash(first), hash(other))
    for value_fn in not_equal_value_fns:
      different = structure.Structure.from_value(value_fn())
      for same in equal_structures:
        self.assertNotEqual(same, different)  # check __ne__ operator.
      for same in equal_structures:
        self.assertFalse(same == different)  # check __eq__ operator.

  @parameterized.named_parameters(
      ("RaggedTensor_RaggedRank",
       structure.RaggedTensorStructure(dtypes.int32, None, 1),
       structure.RaggedTensorStructure(dtypes.int32, None, 2)),
      ("RaggedTensor_Shape",
       structure.RaggedTensorStructure(dtypes.int32, [3, None], 1),
       structure.RaggedTensorStructure(dtypes.int32, [5, None], 1)),
      ("RaggedTensor_DType",
       structure.RaggedTensorStructure(dtypes.int32, None, 1),
       structure.RaggedTensorStructure(dtypes.float32, None, 1)),
      )
  def testInequality(self, s1, s2):
    """Checks that two differing structures compare as unequal."""
    # pylint: disable=g-generic-assert
    # `assertNotEqual` goes through `!=`; the explicit comparison below goes
    # through `==`, so both operators are exercised.
    self.assertNotEqual(s1, s2)
    structures_equal = s1 == s2
    self.assertFalse(structures_equal)

  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]), lambda:
       sparse_tensor.SparseTensor(indices=[[3]], values=[-1], dense_shape=[5])),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: {
          "a": constant_op.constant(42.0),
          "b": constant_op.constant([4, 5, 6])
      }, lambda: {
          "a": constant_op.constant([1, 2, 3]),
          "b": constant_op.constant(37.0)
      }),
  )
  def testHash(self, value1_fn, value2_fn, value3_fn):
    """Checks hashes agree for equal structures and differ otherwise."""
    matching_a, matching_b, distinct = [
        structure.Structure.from_value(value_fn())
        for value_fn in (value1_fn, value2_fn, value3_fn)
    ]
    # `hash` is called afresh for every assertion, so the self-comparison also
    # checks that hashing the same structure twice is repeatable.
    for same in (matching_a, matching_b):
      self.assertEqual(hash(matching_a), hash(same))
    for same in (matching_a, matching_b):
      self.assertNotEqual(hash(same), hash(distinct))

  @parameterized.named_parameters(
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      ),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),),
      (
          "Nested_0",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3])
          },
      ),
      (
          "Nested_1",
          lambda: {
              "a":
                  constant_op.constant(37.0),
              "b": (sparse_tensor.SparseTensor(
                  indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                    sparse_tensor.SparseTensor(
                        indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
          },
      ),
  )
  def testRoundTripConversion(self, value_fn):
    """Checks a value survives `_to_tensor_list` then `_from_tensor_list`."""
    original = value_fn()
    value_structure = structure.Structure.from_value(original)

    def maybe_stack_ta(v):
      # TensorArrays are stacked before evaluation; other values pass through.
      return v.stack() if isinstance(v, tensor_array_ops.TensorArray) else v

    expected = self.evaluate(maybe_stack_ta(original))
    tensor_list = value_structure._to_tensor_list(original)
    round_tripped = value_structure._from_tensor_list(tensor_list)
    actual = self.evaluate(maybe_stack_ta(round_tripped))

    ragged_types = (ragged_tensor.RaggedTensor,
                    ragged_tensor_value.RaggedTensorValue)
    for before, after in zip(nest.flatten(expected), nest.flatten(actual)):
      if isinstance(before, sparse_tensor.SparseTensorValue):
        # Sparse values are compared component by component.
        self.assertAllEqual(before.indices, after.indices)
        self.assertAllEqual(before.values, after.values)
        self.assertAllEqual(before.dense_shape, after.dense_shape)
      elif isinstance(before, ragged_types):
        self.assertRaggedEqual(before, after)
      else:
        self.assertAllEqual(before, after)

  # pylint: enable=g-long-lambda

  def testIncompatibleStructure(self):
    """Checks that mismatched value/structure pairs are rejected.

    Defines three mutually incompatible values/structures, and asserts that:
    1. Using one structure to flatten a value with an incompatible structure
       fails.
    2. Using one structure to restructure a flattened value with an
       incompatible structure fails.
    """
    dense_value = constant_op.constant(42.0)
    dense_structure = structure.Structure.from_value(dense_value)
    dense_flat = dense_structure._to_tensor_list(dense_value)

    sparse_value = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    sparse_structure = structure.Structure.from_value(sparse_value)
    sparse_flat = sparse_structure._to_tensor_list(sparse_value)

    nested_value = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    nested_structure = structure.Structure.from_value(nested_value)
    nested_flat = nested_structure._to_tensor_list(nested_value)

    not_sparse_re = "Input must be a SparseTensor"
    single_variant_re = (
        "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3.")
    two_flat_values_re = "Expected 2 flat values in NestedStructure but got 1."

    # Each entry is (exception, message regexp, conversion method, argument).
    # The entries are executed in the order listed.
    failing_conversions = [
        # 1. Flattening a value with an incompatible structure.
        (ValueError,
         r"SparseTensor.* is not convertible to a tensor with "
         r"dtype.*float32.* and shape \(\)",
         dense_structure._to_tensor_list, sparse_value),
        (ValueError,
         r"Value \{.*\} is not convertible to a tensor with "
         r"dtype.*float32.* and shape \(\)",
         dense_structure._to_tensor_list, nested_value),
        (TypeError, not_sparse_re,
         sparse_structure._to_tensor_list, dense_value),
        (TypeError, not_sparse_re,
         sparse_structure._to_tensor_list, nested_value),
        (ValueError,
         "Tensor.* not compatible with the nested structure "
         ".*TensorStructure.*TensorStructure",
         nested_structure._to_tensor_list, dense_value),
        (ValueError,
         "SparseTensor.* not compatible with the nested structure "
         ".*TensorStructure.*TensorStructure",
         nested_structure._to_tensor_list, sparse_value),
        # 2. Restructuring a flattened value with an incompatible structure.
        (ValueError,
         r"Cannot convert.*with dtype.*float32.* and shape \(\)",
         dense_structure._from_tensor_list, sparse_flat),
        (ValueError,
         "TensorStructure corresponds to a single tf.Tensor.",
         dense_structure._from_tensor_list, nested_flat),
        (ValueError, single_variant_re,
         sparse_structure._from_tensor_list, dense_flat),
        (ValueError, single_variant_re,
         sparse_structure._from_tensor_list, nested_flat),
        (ValueError, two_flat_values_re,
         nested_structure._from_tensor_list, dense_flat),
        (ValueError, two_flat_values_re,
         nested_structure._from_tensor_list, sparse_flat),
    ]
    for error_type, regexp, convert, argument in failing_conversions:
      with self.assertRaisesRegexp(error_type, regexp):
        convert(argument)

  def testIncompatibleNestedStructure(self):
    """Checks that mismatched nested value/structure pairs are rejected.

    Defines three mutually incompatible nested values/structures, and asserts
    that:
    1. Using one structure to flatten a value with an incompatible structure
       fails.
    2. Using one structure to restructure a flattened value with an
       incompatible structure fails.
    """
    tensors_value = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    tensors_structure = structure.Structure.from_value(tensors_value)
    tensors_flat = tensors_structure._to_tensor_list(tensors_value)

    # `mixed_value` has compatible nested structure with `tensors_value`, but
    # different classes.
    mixed_value = {
        "a":
            constant_op.constant(37.0),
        "b":
            sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    mixed_structure = structure.Structure.from_value(mixed_value)
    mixed_flat = mixed_structure._to_tensor_list(mixed_value)

    # `deep_value` has incompatible nested structure with `tensors_value` and
    # `mixed_value`.
    deep_value = {
        "a":
            constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    deep_structure = structure.Structure.from_value(deep_value)
    deep_flat = deep_structure._to_tensor_list(deep_value)

    two_sparse_re = (
        "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure")
    got_3_expected_2_re = "Expected 2 flat values in NestedStructure but got 3."
    got_2_expected_3_re = "Expected 3 flat values in NestedStructure but got 2."

    # Each entry is (message regexp, conversion method, argument); every
    # conversion is expected to raise ValueError. The entries are executed in
    # the order listed.
    failing_conversions = [
        # 1. Flattening a value with an incompatible structure.
        ("SparseTensor.* not compatible with the nested structure "
         ".*TensorStructure",
         tensors_structure._to_tensor_list, mixed_value),
        (two_sparse_re, tensors_structure._to_tensor_list, deep_value),
        ("Tensor.* not compatible with the nested structure "
         ".*SparseTensorStructure",
         mixed_structure._to_tensor_list, tensors_value),
        # NOTE(review): this repeats the structure/value pair from two entries
        # above; possibly `mixed_structure` with `deep_value` was intended --
        # confirm before changing, since the expected message may differ.
        (two_sparse_re, tensors_structure._to_tensor_list, deep_value),
        # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
        # needs to account for "a" coming before or after "b". It might be
        # worth adding a deterministic repr for these error messages (among
        # other improvements).
        ("Tensor.*Tensor.* not compatible with the nested structure "
         ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
         "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)",
         deep_structure._to_tensor_list, tensors_value),
        ("(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
         "not compatible with the nested structure .*"
         "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
         "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)",
         deep_structure._to_tensor_list, mixed_value),
        # 2. Restructuring a flattened value with an incompatible structure.
        (r"Cannot convert.*with dtype.*int32.* and shape \(3,\)",
         tensors_structure._from_tensor_list, mixed_flat),
        (got_3_expected_2_re, tensors_structure._from_tensor_list, deep_flat),
        ("SparseTensorStructure corresponds to a single tf.variant "
         "vector of length 3.",
         mixed_structure._from_tensor_list, tensors_flat),
        (got_3_expected_2_re, mixed_structure._from_tensor_list, deep_flat),
        (got_2_expected_3_re, deep_structure._from_tensor_list, tensors_flat),
        (got_2_expected_3_re, deep_structure._from_tensor_list, mixed_flat),
    ]
    for regexp, convert, argument in failing_conversions:
      with self.assertRaisesRegexp(ValueError, regexp):
        convert(argument)

  @parameterized.named_parameters(
      ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", dtypes.int32, tensor_shape.matrix(
          2, 2), sparse_tensor.SparseTensor,
       structure.SparseTensorStructure(dtypes.int32, [2, 2])),
      ("TensorArray_0", dtypes.int32, tensor_shape.as_shape(
          [None, True, 2, 2]), tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
      ("TensorArray_1", dtypes.int32, tensor_shape.as_shape(
          [True, None, 2, 2]), tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
      ("TensorArray_2", dtypes.int32, tensor_shape.as_shape(
          [True, False, 2, 2]), tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
      ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, 2),
       structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1),
       structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1)),
      ("Nested", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.scalar(),
          "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      },
       structure.NestedStructure({
           "a":
               structure.TensorStructure(dtypes.float32, []),
           "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                 structure.TensorStructure(dtypes.string, []))
       })),
  )
  def testConvertLegacyStructure(self, output_types, output_shapes,
                                 output_classes, expected_structure):
    """Checks conversion of legacy types/shapes/classes to a structure."""
    converted = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    # Mutual compatibility is asserted rather than equality.
    expected_accepts_converted = expected_structure.is_compatible_with(
        converted)
    self.assertTrue(expected_accepts_converted)
    converted_accepts_expected = converted.is_compatible_with(
        expected_structure)
    self.assertTrue(converted_accepts_expected)

  def testNestedNestedStructure(self):
    """Checks conversions for a manually built nested `NestedStructure`."""
    # Although `Structure.from_value()` will not construct one, a nested
    # structure containing nested `NestedStructure` objects can occur if a
    # structure is constructed manually.
    int64_structure = structure.TensorStructure(dtypes.int64, [])
    inner = structure.NestedStructure(
        (structure.TensorStructure(dtypes.float32, []),
         structure.TensorStructure(dtypes.string, [])))
    outer = structure.NestedStructure((int64_structure, inner))

    int64_t = constant_op.constant(37, dtype=dtypes.int64)
    float32_t = constant_op.constant(42.0)
    string_t = constant_op.constant("Foo")
    leaves = [int64_t, float32_t, string_t]

    # Flattening must hand back the very same tensor objects, in order.
    flat = outer._to_tensor_list((int64_t, (float32_t, string_t)))
    for expected, actual in zip(leaves, flat):
      self.assertIs(expected, actual)

    # Both reconstruction paths must restore the original nesting around the
    # identical tensor objects.
    rebuilders = (outer._from_tensor_list, outer._from_compatible_tensor_list)
    for rebuild in rebuilders:
      rebuilt_int64, (rebuilt_float32, rebuilt_string) = rebuild(flat)
      self.assertIs(int64_t, rebuilt_int64)
      self.assertIs(float32_t, rebuilt_float32)
      self.assertIs(string_t, rebuilt_string)

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
       structure.TensorStructure(dtypes.float32, [32])),
      ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
       structure.TensorStructure(dtypes.float32, [None])),
      ("SparseTensor", structure.SparseTensorStructure(dtypes.float32, [None]),
       32, structure.SparseTensorStructure(dtypes.float32, [32, None])),
      ("SparseTensorUnknown",
       structure.SparseTensorStructure(dtypes.float32, [4]), None,
       structure.SparseTensorStructure(dtypes.float32, [None, 4])),
      ("RaggedTensor",
       structure.RaggedTensorStructure(dtypes.float32, [2, None], 1), 32,
       structure.RaggedTensorStructure(dtypes.float32, [32, 2, None], 2)),
      ("RaggedTensorUnknown",
       structure.RaggedTensorStructure(dtypes.float32, [4, None], 1), None,
       structure.RaggedTensorStructure(dtypes.float32, [None, 4, None], 2)),
      ("Nested", structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                structure.TensorStructure(dtypes.string, []))}), 128,
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, [128]),
           "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                 structure.TensorStructure(dtypes.string, [128]))})),
  )
  def testBatch(self, element_structure, batch_size,
                expected_batched_structure):
    """Checks that `_batch()` yields the expected batched structure.

    Args:
      element_structure: The `Structure` of a single element.
      batch_size: The batch size passed to `_batch()`; may be `None`.
      expected_batched_structure: The `Structure` the result must be mutually
        compatible with.
    """
    actual = element_structure._batch(batch_size)
    # Compatibility has to hold in both directions.
    actual_accepts_expected = actual.is_compatible_with(
        expected_batched_structure)
    self.assertTrue(actual_accepts_expected)
    expected_accepts_actual = expected_batched_structure.is_compatible_with(
        actual)
    self.assertTrue(expected_accepts_actual)

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
       structure.TensorStructure(dtypes.float32, [])),
      ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor",
       structure.SparseTensorStructure(dtypes.float32, [32, None]),
       structure.SparseTensorStructure(dtypes.float32, [None])),
      ("SparseTensorUnknown",
       structure.SparseTensorStructure(dtypes.float32, [None, 4]),
       structure.SparseTensorStructure(dtypes.float32, [4])),
      ("RaggedTensor",
       structure.RaggedTensorStructure(dtypes.float32, [32, 4, None], 2),
       structure.RaggedTensorStructure(dtypes.float32, [4, None], 1)),
      ("RaggedTensorUnknown",
       structure.RaggedTensorStructure(dtypes.float32, [None, None, 4], 2),
       structure.RaggedTensorStructure(dtypes.float32, [None, 4], 1)),
      ("Nested", structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, [128]),
          "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                structure.TensorStructure(dtypes.string, [None]))}),
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                 structure.TensorStructure(dtypes.string, []))})),
  )
  def testUnbatch(self, element_structure, expected_unbatched_structure):
    """Checks that `_unbatch()` yields the expected per-element structure.

    Args:
      element_structure: The batched `Structure` to unbatch.
      expected_unbatched_structure: The `Structure` the result must be mutually
        compatible with.
    """
    actual = element_structure._unbatch()
    # Compatibility has to hold in both directions.
    actual_accepts_expected = actual.is_compatible_with(
        expected_unbatched_structure)
    self.assertTrue(actual_accepts_expected)
    expected_accepts_actual = expected_unbatched_structure.is_compatible_with(
        actual)
    self.assertTrue(expected_accepts_actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
       lambda: constant_op.constant([1.0, 2.0])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0]], values=[13], dense_shape=[2])),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
       lambda: ragged_factory_ops.constant([[1]])),
      ("Nest", lambda: (
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
       lambda: (constant_op.constant([1.0, 2.0]), sparse_tensor.SparseTensor(
           indices=[[0]], values=[13], dense_shape=[2]))),
  )
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    """Checks `_to_batched_tensor_list()` against a known first element.

    Args:
      value_fn: Zero-argument callable returning the batched value.
      element_0_fn: Zero-argument callable returning the expected 0th element
        of that batch.
    """
    batched_value = value_fn()
    s = structure.Structure.from_value(batched_value)
    batched_tensor_list = s._to_batched_tensor_list(batched_value)

    # Every test case uses a batch dimension of 2.
    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
    # tensors in which we store sparse tensors, so those are skipped.
    shape_checkable = [
        t for t in batched_tensor_list if t.dtype != dtypes.variant
    ]
    for t in shape_checkable:
      self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

    # Slice out element 0 of each batched tensor, rebuild the value with the
    # unbatched structure, and compare it to the expected element.
    expected_element_0 = self.evaluate(element_0_fn())
    unbatched_s = s._unbatch()
    first_slices = [t[0] for t in batched_tensor_list]
    actual_element_0 = unbatched_s._from_tensor_list(first_slices)

    flat_pairs = zip(
        nest.flatten(expected_element_0), nest.flatten(actual_element_0))
    for expected, actual in flat_pairs:
      # Pick the comparison helper matching the kind of the expected value.
      if sparse_tensor.is_sparse(expected):
        check = self.assertSparseValuesEqual
      elif ragged_tensor.is_ragged(expected):
        check = self.assertRaggedEqual
      else:
        check = self.assertAllEqual
      check(expected, actual)
示例#11
0
  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    The elements of this dataset correspond to records from the file(s).
    RFC 4180 format is expected for CSV files
    (https://tools.ietf.org/html/rfc4180).
    Note that leading and trailing spaces are allowed in int and float fields.

    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
    different data types:
    ```
    abcdefg,4.28E10,5.55E6,12
    hijklmn,-5.3E14,,2
    ```

    We can construct a CsvDataset from it as follows:
    ```python
    tf.enable_eager_execution()

    dataset = tf.data.experimental.CsvDataset(
        "my_file*.csv",
        [tf.float32,  # Required field, use dtype or empty tensor
         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
         tf.int32,  # Required field, use dtype or empty tensor
         ],
        select_cols=[1,2,3]  # Only parse last three columns
    )
    ```

    The expected output of its iterations is:
    ```python
    for element in dataset:
      print(element)

    >> (4.28e10, 5.55e6, 12)
    >> (-5.3e14, 0.0, 2)
    ```

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_cols` are specified, these must have the same
        lengths, and `record_defaults` is assumed to be sorted in order of
        increasing column index.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
        double quotation marks as regular characters inside of string fields
        (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will
        be treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns.
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    # `None` is mapped to the empty string for the compression type.
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    # Entries found in `_ACCEPTABLE_CSV_TYPES` (presumably the bare `DType`s
    # listed in the docstring -- defined outside this block) are replaced by an
    # empty tensor of that dtype; all other entries are passed through as-is.
    record_defaults = [
        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
        for x in record_defaults
    ]
    self._record_defaults = ops.convert_n_to_tensor(
        record_defaults, name="record_defaults")
    # `None` falls back to `_DEFAULT_READER_BUFFER_SIZE_BYTES`.
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._header = ops.convert_to_tensor(
        header, dtype=dtypes.bool, name="header")
    self._field_delim = ops.convert_to_tensor(
        field_delim, dtype=dtypes.string, name="field_delim")
    self._use_quote_delim = ops.convert_to_tensor(
        use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
    self._na_value = ops.convert_to_tensor(
        na_value, dtype=dtypes.string, name="na_value")
    # `None` is mapped to an empty int64 list for the selected columns.
    self._select_cols = convert.optional_param_to_tensor(
        "select_cols",
        select_cols,
        argument_default=[],
        argument_dtype=dtypes.int64,
    )
    # Each element is a tuple with one scalar component per record default,
    # typed with the dtype of that default tensor.
    self._structure = structure.NestedStructure(
        tuple(structure.TensorStructure(d.dtype, [])
              for d in self._record_defaults))
    variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
        filenames=self._filenames,
        record_defaults=self._record_defaults,
        buffer_size=self._buffer_size,
        header=self._header,
        output_shapes=self._structure._flat_shapes,  # pylint: disable=protected-access
        field_delim=self._field_delim,
        use_quote_delim=self._use_quote_delim,
        na_value=self._na_value,
        select_cols=self._select_cols,
        compression_type=self._compression_type)
    super(CsvDatasetV2, self).__init__(variant_tensor)
示例#12
0
  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func.

    The reduce function is traced repeatedly, each time with a possibly
    weakened state shape, until the state shapes it returns no longer force
    the input state shapes to change. The final wrapped function is stored in
    `self._reduce_func` and added to the default graph.

    Args:
      reduce_func: The user-supplied function mapping `(state, element)` to a
        new state.
      input_dataset: The dataset whose element structure is fed to
        `reduce_func` as its second argument.

    Raises:
      TypeError: If the classes or dtypes returned by `reduce_func` do not
        match those of the initial state produced by `self._init_func`.
    """

    # Iteratively rerun the reduce function until reaching a fixed point on
    # `self._state_structure`.
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes
    need_to_rerun = True
    while need_to_rerun:

      # Trace without adding to the graph; only the final trace is added.
      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          reduce_func,
          self._transformation_name(),
          input_structure=structure.NestedStructure(
              (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
          add_to_graph=False)

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(wrapped_func.output_classes),
          nest.flatten(state_classes)):
        if not issubclass(new_state_class, state_class):
          # Report the initial state's classes via `self._init_func`, in the
          # same way as the dtype check below; this method never sets a
          # `self._state_classes` attribute.
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_classes, wrapped_func.output_classes))

      # Extract and validate type information from the returned values.
      for new_state_type, state_type in zip(
          nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_types, wrapped_func.output_types))

      # Extract shape information from the returned values.
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # A rerun is needed as soon as any state shape had to be weakened, i.e.
      # it lost its rank or changed any of its dimensions.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # Feed the weakened shapes back in as the state shapes for the next
        # trace.
        state_shapes = nest.pack_sequence_as(
            self._init_func.output_shapes, weakened_state_shapes)
        self._state_structure = structure.convert_legacy_structure(
            state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())
示例#13
0
class OptionalTest(test.TestCase, parameterized.TestCase):
  """Tests for `optional_ops.Optional` and `optional_ops.OptionalStructure`."""

  @test_util.run_in_graph_and_eager_modes
  def testFromValue(self):
    """An optional built from a scalar tensor has and returns that value."""
    opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual(37.0, self.evaluate(opt.get_value()))

  @test_util.run_in_graph_and_eager_modes
  def testFromStructuredValue(self):
    """An optional built from a dict/tuple nest returns the same nest."""
    opt = optional_ops.Optional.from_value({
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    })
    self.assertTrue(self.evaluate(opt.has_value()))
    # String tensors are compared against bytes after evaluation.
    self.assertEqual({
        "a": 37.0,
        "b": ([b"Foo"], b"Bar")
    }, self.evaluate(opt.get_value()))

  @test_util.run_in_graph_and_eager_modes
  def testFromSparseTensor(self):
    """An optional built from sparse tensor values preserves each component."""
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    st_1 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))
    opt = optional_ops.Optional.from_value((st_0, st_1))
    self.assertTrue(self.evaluate(opt.has_value()))
    val_0, val_1 = opt.get_value()
    # Compare indices, values and dense shape of each sparse tensor in turn.
    for expected, actual in [(st_0, val_0), (st_1, val_1)]:
      self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
      self.assertAllEqual(expected.values, self.evaluate(actual.values))
      self.assertAllEqual(expected.dense_shape,
                          self.evaluate(actual.dense_shape))

  @test_util.run_in_graph_and_eager_modes
  def testFromNone(self):
    """An optional built via `none_from_structure()` holds no value."""
    value_structure = structure.TensorStructure(dtypes.float32, [])
    opt = optional_ops.Optional.none_from_structure(value_structure)
    self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
    # A different shape or a different dtype must not be compatible.
    self.assertFalse(
        opt.value_structure.is_compatible_with(
            structure.TensorStructure(dtypes.float32, [1])))
    self.assertFalse(
        opt.value_structure.is_compatible_with(
            structure.TensorStructure(dtypes.int32, [])))
    self.assertFalse(self.evaluate(opt.has_value()))
    # Asking an empty optional for its value is an error.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  @test_util.run_in_graph_and_eager_modes
  def testCopyToGPU(self):
    """Optionals created on CPU stay usable after an identity op on GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    # Create one optional with a value and one without, both on the CPU.
    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          structure.TensorStructure(dtypes.float32, []))

    # Re-wrap the variant tensors behind an identity op placed on the GPU and
    # build the has_value/get_value ops in the same device scope.
    with ops.device("/gpu:0"):
      gpu_optional_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.value_structure)
      gpu_optional_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.value_structure)

      gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
      gpu_optional_with_value_values = gpu_optional_with_value.get_value()

      gpu_optional_none_has_value = gpu_optional_none.has_value()

    self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  def _assertElementValueEqual(self, expected, actual):
    """Recursively compares an expected element value with an actual one.

    Dicts are compared key by key, `SparseTensorValue`s component by
    component, and everything else with `assertAllEqual`.

    Args:
      expected: The expected value.
      actual: The value to check against `expected`.
    """
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self._assertElementValueEqual(expected[k], actual[k])
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
    else:
      self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testOptionalStructure(self, tf_value_fn, expected_value_structure):
    """Checks the `OptionalStructure` inferred for an optional.

    Args:
      tf_value_fn: Zero-argument callable returning the value to wrap in an
        optional.
      expected_value_structure: The `Structure` expected for that value.
    """
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)

    # The optional's value structure must be mutually compatible with the
    # expected structure.
    self.assertTrue(
        expected_value_structure.is_compatible_with(opt.value_structure))
    self.assertTrue(
        opt.value_structure.is_compatible_with(expected_value_structure))

    # The optional itself is represented as a single scalar variant tensor.
    opt_structure = structure.Structure.from_value(opt)
    self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
    self.assertTrue(opt_structure.is_compatible_with(opt_structure))
    self.assertTrue(opt_structure._value_structure.is_compatible_with(
        expected_value_structure))
    self.assertEqual([dtypes.variant], opt_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

    # All OptionalStructure objects are not compatible with a non-optional
    # value.
    non_optional_structure = structure.Structure.from_value(
        constant_op.constant(42.0))
    self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_opt = opt_structure._from_tensor_list(
        opt_structure._to_tensor_list(opt))
    if isinstance(tf_value, optional_ops.Optional):
      # An optional wrapping an optional needs two `get_value()` calls to
      # reach the innermost value.
      self.assertEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self.assertEqual(
          self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))

  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
       False),
      ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
                "b": sparse_tensor.SparseTensorValue(
                    indices=[[0, 0], [1, 1]],
                    values=np.array([-1., 1.], dtype=np.float32),
                    dense_shape=[2, 2])},
       lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
                "b": sparse_tensor.SparseTensor(
                    indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
                    dense_shape=[2, 2])}, False),
  )
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
    """Checks `get_next_as_optional()` before, during and after iteration.

    Args:
      np_value: The value each dataset element is expected to evaluate to.
      tf_value_fn: Zero-argument callable returning a value whose structure
        the optional's value structure must be compatible with.
      works_on_gpu: Whether this case may run when a GPU is available.
    """
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    # The dataset yields `np_value` three times.
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
    iterator = ds.make_initializable_iterator()
    next_elem = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(next_elem, optional_ops.Optional)
    self.assertTrue(
        next_elem.value_structure.is_compatible_with(
            structure.Structure.from_value(tf_value_fn())))
    elem_has_value_t = next_elem.has_value()
    elem_value_t = next_elem.get_value()
    with self.cached_session() as sess:
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_value_t)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      sess.run(iterator.initializer)
      for _ in range(3):
        elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(sess.run(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(elem_value_t)
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testAsSerializedGraph(self):
        """Checks that the serialized graph of a range dataset contains it."""
        dataset = dataset_ops.Dataset.range(10)
        with self.cached_session() as sess:
            graph = graph_pb2.GraphDef().FromString(
                sess.run(dataset._as_serialized_graph()))
            # Look for the dataset op itself. Comparing with `!=` would be
            # satisfied by any unrelated node in the graph and so would not
            # verify anything about the serialized dataset.
            self.assertTrue(
                any(node.op == "RangeDataset" for node in graph.node))

    @staticmethod
    def make_apply_fn(dataset):
        """Returns a function that caches its argument via `Dataset.apply()`.

        Args:
          dataset: Unused; the returned function takes its own dataset.

        Returns:
          A one-argument callable mapping a dataset to
          `dataset.apply(<cache transformation>)`.
        """
        def _cache_transformation(ds):
            # The transformation handed to `apply()`: cache the dataset.
            return ds.cache()

        def apply_fn(dataset):
            return dataset.apply(_cache_transformation)

        return apply_fn

    @staticmethod
    def make_gen():
        """Returns a generator function that yields the single value 42."""
        def gen():
            for value in (42,):
                yield value

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        """Returns a function that interleaves empty range datasets.

        Args:
          dataset: Unused; the returned function takes its own dataset.
          num_parallel_calls: Forwarded to `Dataset.interleave()`.

        Returns:
          A one-argument callable applying `interleave` with a cycle length
          of 2 to the dataset it is given.
        """
        def interleave_fn(dataset):
            interleave_kwargs = {
                "cycle_length": 2,
                "num_parallel_calls": num_parallel_calls,
            }
            # Every input element maps to an empty range dataset.
            return dataset.interleave(
                lambda x: dataset_ops.Dataset.range(0), **interleave_kwargs)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetOpsTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        # This case must build the dataset with `from_tensor_slices`;
        # otherwise it merely repeats the "FromTensors" case above.
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Checks the number of inputs reported by a source dataset.

        Args:
          dataset_fn: Zero-argument callable that builds the dataset.
          num_inputs: Expected length of the dataset's `_inputs()` list.
        """
        self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

    def testDatasetComplexSourceInputs(self):
        """A dataset built from sparse tensor slices reports no inputs."""
        slices = sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1]))
        dataset = dataset_ops.Dataset.from_sparse_tensor_slices(slices)
        input_count = len(dataset._inputs())
        self.assertEqual(0, input_count)

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation reports exactly its source as input.

        Args:
          dataset_fn: Callable applying the transformation to a dataset.
          input_dataset_fn: Zero-argument callable building the source.
        """
        source = input_dataset_fn()
        transformed = dataset_fn(source)
        self.assertEqual([source], transformed._inputs())

    def testUnaryTransformationInputsApply(self):
        """A transformation applied via `apply()` reports its source as input."""
        source = dataset_ops.Dataset.range(0)
        apply_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        result = apply_fn(source)
        self.assertEqual([source], result._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """An interleave transformation reports exactly its source as input.

        Args:
          interleave_fn_args: Positional arguments for `make_interleave_fn`.
          input_dataset_fn: Zero-argument callable building the source.
        """
        source = input_dataset_fn()
        interleave = self.make_interleave_fn(*interleave_fn_args)
        interleaved = interleave(source)
        self.assertEqual([source], interleaved._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both sources, in order, as inputs.

        Args:
          dataset_fn: Callable combining two datasets into one.
          input1_fn: Zero-argument callable building the first source.
          input2_fn: Zero-argument callable building the second source.
        """
        first = input1_fn()
        second = input2_fn()
        expected_inputs = [first, second]
        combined = dataset_fn(first, second)
        self.assertEqual(expected_inputs, combined._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """A variadic transformation reports its flattened sources as inputs.

        Args:
          dataset_fn: Callable combining a nest of datasets into one.
          input_datasets_fn: Zero-argument callable building that nest.
        """
        sources = input_datasets_fn()
        expected_inputs = nest.flatten(sources)
        combined = dataset_fn(sources)
        self.assertEqual(expected_inputs, combined._inputs())

    def testCollectInputs(self):
        """Walks a dataset's input graph breadth-first and counts each node."""
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        # Shared inputs are deliberately visited once per path reaching them.
        visited = []
        pending = [ds3]
        while pending:
            current = pending.pop(0)
            pending.extend(current._inputs())
            visited.append(current)

        for expected_count, ds in ((5, ds1), (2, ds2), (1, ds3)):
            self.assertEqual(expected_count, visited.count(ds))

    def testOptionsDefault(self):
        """A fresh dataset carries default-constructed options."""
        dataset = dataset_ops.Dataset.range(0)
        default_options = dataset_ops.Options()
        self.assertEqual(default_options, dataset.options())

    def testOptionsOnce(self):
        """Options set once are still reported after a later transformation."""
        opts = dataset_ops.Options()
        with_opts = dataset_ops.Dataset.range(0).with_options(opts)
        cached = with_opts.cache()
        self.assertEqual(opts, cached.options())

    def testOptionsTwiceSame(self):
        """Applying the same options twice leaves them unchanged."""
        opts = dataset_ops.Options()
        opts.experimental_autotune = True
        ds = dataset_ops.Dataset.range(0)
        for _ in range(2):
            ds = ds.with_options(opts)
        self.assertEqual(opts, ds.options())

    def testOptionsTwiceDifferent(self):
        """Options that set different fields are merged together."""
        autotune_on = dataset_ops.Options()
        autotune_on.experimental_autotune = True
        fusion_off = dataset_ops.Options()
        fusion_off.experimental_filter_fusion = False

        ds = dataset_ops.Dataset.range(0)
        for opts in (autotune_on, fusion_off):
            ds = ds.with_options(opts)

        self.assertTrue(ds.options().experimental_autotune)
        # assertFalse would also accept None, so require the literal False.
        self.assertIs(ds.options().experimental_filter_fusion, False)

    def testOptionsTwiceDifferentError(self):
        """Options that set one field to conflicting values cannot be merged."""
        autotune_on = dataset_ops.Options()
        autotune_on.experimental_autotune = True
        autotune_off = dataset_ops.Options()
        autotune_off.experimental_autotune = False

        expected_message = "Cannot merge incompatible values"
        with self.assertRaisesRegexp(ValueError, expected_message):
            ds = dataset_ops.Dataset.range(0)
            ds = ds.with_options(autotune_on)
            ds.with_options(autotune_off)

    def testOptionsMergeOptionsFromMultipleInputs(self):
        """A zipped dataset merges the options of all of its inputs."""
        autotune_on = dataset_ops.Options()
        autotune_on.experimental_autotune = True
        fusion_on = dataset_ops.Options()
        fusion_on.experimental_filter_fusion = True

        left = dataset_ops.Dataset.range(0).with_options(autotune_on)
        right = dataset_ops.Dataset.range(0).with_options(fusion_on)
        zipped = dataset_ops.Dataset.zip((left, right))

        self.assertTrue(zipped.options().experimental_autotune)
        self.assertTrue(zipped.options().experimental_filter_fusion)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        (
            "Tensor",
            lambda: constant_op.constant(37.0),
            structure.TensorStructure(dtypes.float32, []),
        ),
        (
            "SparseTensor",
            lambda: sparse_tensor.SparseTensor(
                indices=[[0]],
                values=constant_op.constant([0], dtype=dtypes.int32),
                dense_shape=[1]),
            structure.SparseTensorStructure(dtypes.int32, [1]),
        ),
        (
            "Nest",
            lambda: {
                "a": constant_op.constant(37.0),
                "b": (constant_op.constant(["Foo"]),
                      constant_op.constant("Bar")),
            },
            structure.NestedStructure({
                "a": structure.TensorStructure(dtypes.float32, []),
                "b": (structure.TensorStructure(dtypes.string, [1]),
                      structure.TensorStructure(dtypes.string, [])),
            }),
        ),
        (
            "Dataset",
            lambda: dataset_ops.Dataset.from_tensor_slices(
                constant_op.constant([1, 2, 3])),
            dataset_ops.DatasetStructure(
                structure.TensorStructure(dtypes.int32, [])),
        ),
        (
            "Optional",
            lambda: optional_ops.Optional.from_value(37.0),
            optional_ops.OptionalStructure(
                structure.TensorStructure(dtypes.float32, [])),
        ),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        # Wraps the value built by `tf_value_fn` as the sole element of a
        # dataset, then checks the structure derived for that dataset: its
        # type, its element structure, its flat types/shapes, and a round trip
        # through the tensor-list representation.
        source = dataset_ops.Dataset.from_tensors(0)
        dataset = source.map(lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        element_structure = dataset_structure._element_structure
        # Compatibility is asserted in both directions.
        self.assertTrue(
            expected_element_structure.is_compatible_with(element_structure))
        self.assertTrue(
            element_structure.is_compatible_with(expected_element_structure))

        # The dataset itself is expected to flatten to one scalar variant.
        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        expected_flat_shapes = [tensor_shape.scalar()]
        self.assertEqual(expected_flat_shapes, dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via
        # _to_tensor_list() and _from_tensor_list().
        tensor_list = dataset_structure._to_tensor_list(dataset)
        round_trip_dataset = dataset_structure._from_tensor_list(tensor_list)

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset` -- confirm whether that is intentional.
            flattened = dataset.flat_map(lambda nested: nested)
            self.assertDatasetsEqual(value, flattened)
        elif isinstance(value, optional_ops.Optional):
            # Optionals are unwrapped on both sides before comparing.
            unwrapped = round_trip_dataset.map(
                lambda optional: optional.get_value())
            expected = [self.evaluate(value.get_value())]
            self.assertDatasetProduces(
                unwrapped, expected, requires_initialization=True)
        else:
            # The expected element comes from a fresh call to `tf_value_fn`.
            expected = [self.evaluate(tf_value_fn())]
            self.assertDatasetProduces(
                round_trip_dataset, expected, requires_initialization=True)