Пример #1
0
def _test_convert_legacy_structure_combinations():
    """Builds the parameter combinations for legacy-structure conversion.

    Every case is a 4-tuple `(output_types, output_shapes, output_classes,
    expected_structure)`. Each case is handed to `combinations.combine` as
    keyword arguments, and the results are concatenated with `+` in case
    order.

    Returns:
      The concatenated `combinations.combine` results, starting from an empty
      list.
    """

    def tensor_array_case(dynamic_size, infer_shape):
        # The legacy shape of a TensorArray carries `dynamic_size` and
        # `infer_shape` as its two leading entries, followed by the [2, 2]
        # element shape used by all TensorArray cases here.
        return (dtypes.int32,
                tensor_shape.TensorShape([dynamic_size, infer_shape, 2, 2]),
                tensor_array_ops.TensorArray,
                tensor_array_ops.TensorArraySpec([2, 2],
                                                 dtypes.int32,
                                                 dynamic_size=dynamic_size,
                                                 infer_shape=infer_shape))

    cases = [
        # Dense scalar tensor.
        (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor,
         tensor_spec.TensorSpec([], dtypes.float32)),
        # Sparse tensor.
        (dtypes.int32, tensor_shape.TensorShape([2, 2]),
         sparse_tensor.SparseTensor,
         sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
        # Tensor arrays with different flag combinations.
        tensor_array_case(dynamic_size=None, infer_shape=True),
        tensor_array_case(dynamic_size=True, infer_shape=None),
        tensor_array_case(dynamic_size=True, infer_shape=False),
        # A spec instance in the "classes" slot.
        (dtypes.int32, tensor_shape.TensorShape([2, None]),
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
        # Nested dict/tuple structure mixing dense and sparse components.
        (
            {
                "a": dtypes.float32,
                "b": (dtypes.int32, dtypes.string),
            },
            {
                "a": tensor_shape.TensorShape([]),
                "b": (tensor_shape.TensorShape([2, 2]),
                      tensor_shape.TensorShape([])),
            },
            {
                "a": ops.Tensor,
                "b": (sparse_tensor.SparseTensor, ops.Tensor),
            },
            {
                "a": tensor_spec.TensorSpec([], dtypes.float32),
                "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                      tensor_spec.TensorSpec([], dtypes.string)),
            },
        ),
    ]

    all_combinations = []
    for types, shapes, classes, expected in cases:
        all_combinations = all_combinations + combinations.combine(
            output_types=types,
            output_shapes=shapes,
            output_classes=classes,
            expected_structure=expected)
    return all_combinations
Пример #2
0
def convert_legacy_structure(output_types, output_shapes, output_classes):
    """Returns a `Structure` that represents the given legacy structure.

  This method provides a way to convert from the existing `Dataset` and
  `Iterator` structure-related properties to a `Structure` object. A "legacy"
  structure is represented by the `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties.

  The three arguments are flattened and walked in lockstep; the resulting
  per-component specs are packed back using the nesting of `output_classes`.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects corresponding
      to each component of a structured value. A component may also be a
      `TypeSpec` instance, in which case it is used as-is.

  Returns:
    A `Structure`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one of
      the component classes in `output_classes` is not supported.
  """

    def _component_spec(dtype, shape, cls):
        """Builds the spec for a single flattened component."""
        # An already-built spec is passed through untouched.
        if isinstance(cls, type_spec.TypeSpec):
            return cls
        if issubclass(cls, sparse_tensor.SparseTensor):
            return sparse_tensor.SparseTensorSpec(shape, dtype)
        if issubclass(cls, ops.Tensor):
            return tensor_spec.TensorSpec(shape, dtype)
        if issubclass(cls, tensor_array_ops.TensorArray):
            # We sneaked the dynamic_size and infer_shape into the legacy
            # shape: entry 0 is dynamic_size, entry 1 is infer_shape and the
            # remaining entries are the element shape.
            return tensor_array_ops.TensorArraySpec(
                shape[2:],
                dtype,
                dynamic_size=tensor_shape.dimension_value(shape[0]),
                infer_shape=tensor_shape.dimension_value(shape[1]))
        # NOTE(mrry): Since legacy structures produced by iterators only
        # comprise Tensors, SparseTensors, and nests, we do not need to
        # support all structure types here.
        raise TypeError(
            "Could not build a structure for output class {}. Make sure any "
            "component class in `output_classes` inherits from one of the "
            "following classes: `tf.TypeSpec`, `tf.sparse.SparseTensor`, "
            "`tf.Tensor`, `tf.TensorArray`.".format(cls.__name__))

    component_specs = [
        _component_spec(dtype, shape, cls)
        for dtype, shape, cls in zip(nest.flatten(output_types),
                                     nest.flatten(output_shapes),
                                     nest.flatten(output_classes))
    ]
    return nest.pack_sequence_as(output_classes, component_specs)
Пример #3
0
def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
    """Builds a `TensorArraySpec` from arguments given in legacy order.

    The spec constructor is called with `element_shape` first and `dtype`
    second, i.e. the first two parameters of this wrapper are swapped.
    """
    spec = tensor_array_ops.TensorArraySpec(element_shape,
                                            dtype,
                                            dynamic_size=dynamic_size,
                                            infer_shape=infer_shape)
    return spec
Пример #4
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    test_util.TensorFlowTestCase):

    # pylint: disable=g-long-lambda,protected-access
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
         [dtypes.float32], [[]]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
         ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
        ("Nested_0", lambda:
         (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), tuple,
         [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_1", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, dict, [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_2", lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, dict, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None,
                                                                    None]),
    )
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the spec derived from a value and its flat dtypes/shapes.

        An entry of `None` in `expected_shapes` means the matching flat shape
        must report `ndims` of `None`; any other entry is compared against
        the shape's `as_list()`.
        """
        spec = structure.type_spec_from_value(value_fn())
        self.assertIsInstance(spec, expected_structure)

        actual_types = structure.get_flat_tensor_types(spec)
        self.assertEqual(expected_types, actual_types)

        actual_shapes = structure.get_flat_tensor_shapes(spec)
        self.assertLen(actual_shapes, len(expected_shapes))
        for want, got in zip(expected_shapes, actual_shapes):
            if want is not None:
                self.assertEqual(got.as_list(), want)
                continue
            self.assertEqual(got.ndims, None)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=10)
            ], lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.int32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(), size=0)
            ]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), lambda: [
             ragged_factory_ops.constant([[1, 2], [3, 4], []]),
             ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
         ], lambda: [
             ragged_factory_ops.constant(1),
             ragged_factory_ops.constant([1, 2]),
             ragged_factory_ops.constant([[1], [2]]),
             ragged_factory_ops.constant([["a", "b"]]),
         ]),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    @test_util.run_deprecated_v1
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `structure.are_compatible` against a reference value's spec.

        The spec of every value produced by `compatible_values_fn` must be
        compatible with the reference spec, and the spec of every value
        produced by `incompatible_values_fn` must not be.
        """
        reference = original_value_fn()
        should_match = compatible_values_fn()
        should_not_match = incompatible_values_fn()
        reference_spec = structure.type_spec_from_value(reference)

        for candidate in should_match:
            candidate_spec = structure.type_spec_from_value(candidate)
            self.assertTrue(
                structure.are_compatible(reference_spec, candidate_spec))

        for candidate in should_not_match:
            candidate_spec = structure.type_spec_from_value(candidate)
            self.assertFalse(
                structure.are_compatible(reference_spec, candidate_spec))

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
         lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]), lambda:
         ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1),
         lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
         lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )  # pyformat: disable
    def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                       *not_equal_value_fns):
        """Checks `==`, `!=` and `hash` on specs derived from values.

        The specs of the first two values must be equal (with equal component
        hashes); the spec of every remaining value must differ from both.
        """
        # pylint: disable=g-generic-assert
        first = structure.type_spec_from_value(value1_fn())
        second = structure.type_spec_from_value(value2_fn())
        equal_pair = (first, second)

        for other in equal_pair:
            self.assertEqual(first, other)  # check __eq__ operator.
        for other in equal_pair:
            self.assertFalse(first != other)  # check __ne__ operator.

        flat_first = nest.flatten(first)
        flat_second = nest.flatten(second)
        for comp_first, comp_second in zip(flat_first, flat_second):
            # Hashing the same component twice must give the same result.
            self.assertEqual(hash(comp_first), hash(comp_first))
            self.assertEqual(hash(comp_first), hash(comp_second))

        for make_value in not_equal_value_fns:
            different = structure.type_spec_from_value(make_value())
            for spec in equal_pair:
                self.assertNotEqual(spec, different)  # check __ne__ operator.
            for spec in equal_pair:
                self.assertFalse(spec == different)  # check __eq__ operator.

    @parameterized.named_parameters(
        ("RaggedTensor_RaggedRank",
         ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)),
        ("RaggedTensor_Shape",
         ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)),
        ("RaggedTensor_DType",
         ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)),
    )
    def testRaggedStructureInequality(self, s1, s2):
        """Checks that the two given ragged specs compare as unequal."""
        # pylint: disable=g-generic-assert
        self.assertNotEqual(s1, s2)  # check __ne__ operator.
        compares_equal = s1 == s2  # check __eq__ operator.
        self.assertFalse(compares_equal)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )
    def testHash(self, value1_fn, value2_fn, value3_fn):
        """Checks component hashes of specs derived from three values.

        Components of the first two specs must hash identically, while the
        matching components of the third spec must hash differently.
        """
        flat_first = nest.flatten(structure.type_spec_from_value(value1_fn()))
        flat_second = nest.flatten(
            structure.type_spec_from_value(value2_fn()))
        flat_third = nest.flatten(structure.type_spec_from_value(value3_fn()))

        for comp_first, comp_second, comp_third in zip(flat_first,
                                                       flat_second,
                                                       flat_third):
            # Hashing the same component twice must give the same result.
            self.assertEqual(hash(comp_first), hash(comp_first))
            self.assertEqual(hash(comp_first), hash(comp_second))
            self.assertNotEqual(hash(comp_first), hash(comp_third))
            self.assertNotEqual(hash(comp_second), hash(comp_third))

    @parameterized.named_parameters(
        (
            "Tensor",
            lambda: constant_op.constant(37.0),
        ),
        (
            "SparseTensor",
            lambda: sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        ),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
        (
            "RaggedTensor",
            lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
        ),
        (
            "Nested_0",
            lambda: {
                "a": constant_op.constant(37.0),
                "b": constant_op.constant([1, 2, 3])
            },
        ),
        (
            "Nested_1",
            lambda: {
                "a":
                constant_op.constant(37.0),
                "b": (sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                      sparse_tensor.SparseTensor(
                          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
            },
        ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that a value survives `to_tensor_list` -> `from_tensor_list`.

        The original and the round-tripped values are evaluated, flattened
        and compared component by component.
        """
        value = value_fn()
        s = structure.type_spec_from_value(value)

        def maybe_stack_ta(v):
            # A TensorArray is compared through its stacked tensor.
            if isinstance(v, tensor_array_ops.TensorArray):
                return v.stack()
            return v

        before = self.evaluate(maybe_stack_ta(value))
        round_tripped = structure.from_tensor_list(
            s, structure.to_tensor_list(s, value))
        after = self.evaluate(maybe_stack_ta(round_tripped))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        # `zip` stops at the shorter sequence, so without this check a
        # round trip that dropped or added components could go unnoticed.
        self.assertLen(flat_after, len(flat_before))
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                # Sparse values are compared field by field.
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                # Ragged values and everything else are compared directly.
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def preserveStaticShape(self):
        """Checks static shapes of ragged and sparse values after a round trip.

        NOTE(review): this method's name does not start with `test`, so
        standard unittest discovery would presumably not collect it. Confirm
        whether that is intentional before renaming it.
        """
        # Ragged tensor: `row_splits` keeps its static shape, while `values`
        # is expected to have an unknown leading dimension.
        ragged = ragged_factory_ops.constant([[1, 2], [], [3]])
        ragged_spec = structure.type_spec_from_value(ragged)
        ragged_restored = structure.from_tensor_list(
            ragged_spec, structure.to_tensor_list(ragged_spec, ragged))
        self.assertEqual(ragged_restored.row_splits.shape.as_list(),
                         ragged.row_splits.shape.as_list())
        self.assertEqual(ragged_restored.values.shape.as_list(), [None])

        # Sparse tensor: `dense_shape` keeps its static shape, while the
        # number of entries in `indices` and `values` is expected unknown.
        sparse = sparse_tensor.SparseTensor(indices=[[3, 4]],
                                            values=[-1],
                                            dense_shape=[4, 5])
        sparse_spec = structure.type_spec_from_value(sparse)
        sparse_restored = structure.from_tensor_list(
            sparse_spec, structure.to_tensor_list(sparse_spec, sparse))
        self.assertEqual(sparse_restored.indices.shape.as_list(), [None, 2])
        self.assertEqual(sparse_restored.values.shape.as_list(), [None])
        self.assertEqual(sparse_restored.dense_shape.shape.as_list(),
                         sparse.dense_shape.shape.as_list())

    def testPreserveTensorArrayShape(self):
        """Checks that an explicit element shape survives a round trip."""
        original = tensor_array_ops.TensorArray(dtype=dtypes.int32,
                                                size=1,
                                                element_shape=(3, ))
        spec = structure.type_spec_from_value(original)
        flat = structure.to_tensor_list(spec, original)
        restored = structure.from_tensor_list(spec, flat)
        self.assertEqual(restored.element_shape.as_list(), [3])

    def testPreserveInferredTensorArrayShape(self):
        """Checks that an inferred element shape survives a round trip."""
        empty = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
        # No element shape is given up front; it is inferred from the write.
        written = empty.write(0, [1, 2, 3])
        spec = structure.type_spec_from_value(written)
        flat = structure.to_tensor_list(spec, written)
        restored = structure.from_tensor_list(spec, flat)
        self.assertEqual(restored.element_shape.as_list(), [3])

    def testIncompatibleStructure(self):
        """Checks errors raised when specs and values/tensor lists mismatch.

        Three mutually incompatible values (dense tensor, sparse tensor and a
        dict of tensors) are defined, and it is asserted that:
        1. Using one structure to flatten a value with an incompatible
           structure fails.
        2. Using one structure to restructure a flattened value with an
           incompatible structure fails.
        """
        dense_value = constant_op.constant(42.0)
        dense_spec = structure.type_spec_from_value(dense_value)
        dense_flat = structure.to_tensor_list(dense_spec, dense_value)

        sparse_value = sparse_tensor.SparseTensor(indices=[[0, 0]],
                                                  values=[1],
                                                  dense_shape=[1, 1])
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        nest_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        nest_spec = structure.type_spec_from_value(nest_value)
        nest_flat = structure.to_tensor_list(nest_spec, nest_value)

        nesting_mismatch = (
            "The two structures don't have the same nested structure.")

        # (expected exception, message pattern, spec, value to flatten)
        flatten_failures = (
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)", dense_spec, sparse_value),
            (ValueError, nesting_mismatch, dense_spec, nest_value),
            (TypeError, "Neither a SparseTensor nor SparseTensorValue",
             sparse_spec, dense_value),
            (ValueError, nesting_mismatch, sparse_spec, nest_value),
            (ValueError, nesting_mismatch, nest_spec, dense_value),
            (ValueError, nesting_mismatch, nest_spec, sparse_value),
        )
        for error_type, pattern, spec, value in flatten_failures:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (message pattern, spec, tensor list to restructure); every case is
        # expected to raise ValueError.
        restructure_failures = (
            (r"Incompatible input:", dense_spec, sparse_flat),
            ("Expected 1 tensors but got 2.", dense_spec, nest_flat),
            ("Incompatible input: ", sparse_spec, dense_flat),
            ("Expected 1 tensors but got 2.", sparse_spec, nest_flat),
            ("Expected 2 tensors but got 1.", nest_spec, dense_flat),
            ("Expected 2 tensors but got 1.", nest_spec, sparse_flat),
        )
        for pattern, spec, tensors in restructure_failures:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, tensors)

    def testIncompatibleNestedStructure(self):
        """Checks errors raised for mismatching nested specs and values.

        Three mutually incompatible nested values are defined, and it is
        asserted that:
        1. Using one structure to flatten a value with an incompatible
           structure fails.
        2. Using one structure to restructure a flattened value with an
           incompatible structure fails.
        """
        dense_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        dense_spec = structure.type_spec_from_value(dense_nest)
        dense_flat = structure.to_tensor_list(dense_spec, dense_nest)

        # Same nesting as `dense_nest`, but "b" is a different class.
        sparse_nest = {
            "a":
            constant_op.constant(37.0),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 0]],
                                       values=[1],
                                       dense_shape=[1, 1])
        }
        sparse_spec = structure.type_spec_from_value(sparse_nest)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_nest)

        # Nesting differs from both values above: "b" is a pair here.
        pair_nest = {
            "a":
            constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                             values=[1],
                                             dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(indices=[[3, 4]],
                                             values=[-1],
                                             dense_shape=[4, 5]))
        }
        pair_spec = structure.type_spec_from_value(pair_nest)
        pair_flat = structure.to_tensor_list(pair_spec, pair_nest)

        # This pattern matches only the generic message text, not the
        # contents of the dictionaries involved.
        nesting_mismatch = (
            "The two structures don't have the same nested structure.")

        # (expected exception, message pattern, spec, value to flatten)
        flatten_failures = (
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)", dense_spec, sparse_nest),
            (ValueError, nesting_mismatch, dense_spec, pair_nest),
            (TypeError, "Neither a SparseTensor nor SparseTensorValue",
             sparse_spec, dense_nest),
            (ValueError, nesting_mismatch, sparse_spec, pair_nest),
            (ValueError, nesting_mismatch, pair_spec, dense_nest),
            (ValueError, nesting_mismatch, pair_spec, sparse_nest),
        )
        for error_type, pattern, spec, value in flatten_failures:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (message pattern, spec, tensor list to restructure); every case is
        # expected to raise ValueError.
        restructure_failures = (
            (r"Incompatible input:", dense_spec, sparse_flat),
            ("Expected 2 tensors but got 3.", dense_spec, pair_flat),
            ("Incompatible input: ", sparse_spec, dense_flat),
            ("Expected 2 tensors but got 3.", sparse_spec, pair_flat),
            ("Expected 3 tensors but got 2.", pair_spec, dense_flat),
            ("Expected 3 tensors but got 2.", pair_spec, sparse_flat),
        )
        for pattern, spec, tensors in restructure_failures:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, tensors)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.TensorShape(
            []), ops.Tensor, tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor", dtypes.int32, tensor_shape.TensorShape(
            [2, 2]), sparse_tensor.SparseTensor,
         sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
        ("TensorArray_0", dtypes.int32,
         tensor_shape.TensorShape([None, True, 2, 2
                                   ]), tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec(
             [2, 2], dtypes.int32, dynamic_size=None, infer_shape=True)),
        ("TensorArray_1", dtypes.int32,
         tensor_shape.TensorShape([True, None, 2, 2
                                   ]), tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec(
             [2, 2], dtypes.int32, dynamic_size=True, infer_shape=None)),
        ("TensorArray_2", dtypes.int32,
         tensor_shape.TensorShape([True, False, 2, 2
                                   ]), tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec(
             [2, 2], dtypes.int32, dynamic_size=True, infer_shape=False)),
        ("RaggedTensor", dtypes.int32, tensor_shape.TensorShape([2, None]),
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
        ("Nested", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.TensorShape([]),
            "b":
            (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([]))
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string))
        }),
    )
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks `convert_legacy_structure` against the expected structure."""
        legacy_properties = (output_types, output_shapes, output_classes)
        converted = structure.convert_legacy_structure(*legacy_properties)
        self.assertEqual(converted, expected_structure)

    def testNestedNestedStructure(self):
        """Round-trips a nested tuple through the tensor-list helpers.

        The flattened list and the values rebuilt from it are compared by
        identity (`assertIs`) with the tensors that were put in.
        """
        nested_spec = (
            tensor_spec.TensorSpec([], dtypes.int64),
            (
                tensor_spec.TensorSpec([], dtypes.float32),
                tensor_spec.TensorSpec([], dtypes.string),
            ),
        )

        int64_value = constant_op.constant(37, dtype=dtypes.int64)
        float32_value = constant_op.constant(42.0)
        string_value = constant_op.constant("Foo")
        flat_values = [int64_value, float32_value, string_value]

        flattened = structure.to_tensor_list(
            nested_spec, (int64_value, (float32_value, string_value)))
        for want, got in zip(flat_values, flattened):
            self.assertIs(want, got)

        # Both rebuild helpers must hand back the very same tensor objects,
        # arranged in the original nesting.
        for rebuild in (structure.from_tensor_list,
                        structure.from_compatible_tensor_list):
            got_int64, (got_float32, got_string) = rebuild(
                nested_spec, flattened)
            self.assertIs(int64_value, got_int64)
            self.assertIs(float32_value, got_float32)
            self.assertIs(string_value, got_string)

    @parameterized.named_parameters(
        ("Tensor", tensor_spec.TensorSpec([], dtypes.float32), 32,
         tensor_spec.TensorSpec([32], dtypes.float32)),
        ("TensorUnknown", tensor_spec.TensorSpec([], dtypes.float32), None,
         tensor_spec.TensorSpec([None], dtypes.float32)),
        ("SparseTensor", sparse_tensor.SparseTensorSpec([None],
                                                        dtypes.float32), 32,
         sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)),
        ("SparseTensorUnknown",
         sparse_tensor.SparseTensorSpec([4], dtypes.float32), None,
         sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)),
        ("RaggedTensor",
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
         ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
        ("RaggedTensorUnknown",
         ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1), None,
         ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
        ("Nested", {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string))
        }, 128, {
            "a":
            tensor_spec.TensorSpec([128], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([128], dtypes.string))
        }),
    )
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks the result of calling `_batch` on every component spec.

        Args:
          element_structure: Nested structure of component specs.
          batch_size: Value passed to each component's `_batch` method.
          expected_batched_structure: Structure the mapped result must equal.
        """

        def batch_component(component_spec):
            return component_spec._batch(batch_size)

        batched_structure = nest.map_structure(batch_component,
                                               element_structure)
        self.assertEqual(batched_structure, expected_batched_structure)

    @parameterized.named_parameters(
        ("Tensor", tensor_spec.TensorSpec(
            [32], dtypes.float32), tensor_spec.TensorSpec([], dtypes.float32)),
        ("TensorUnknown", tensor_spec.TensorSpec([None], dtypes.float32),
         tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor",
         sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
         sparse_tensor.SparseTensorSpec([None], dtypes.float32)),
        ("SparseTensorUnknown",
         sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
         sparse_tensor.SparseTensorSpec([4], dtypes.float32)),
        ("RaggedTensor",
         ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
         ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
        ("RaggedTensorUnknown",
         ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32, 2),
         ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
        ("Nested", {
            "a":
            tensor_spec.TensorSpec([128], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([None], dtypes.string))
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string))
        }),
    )
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks the result of calling `_unbatch` on every component spec.

        Args:
          element_structure: Nested structure of component specs.
          expected_unbatched_structure: Structure the mapped result must
            equal.
        """

        def unbatch_component(component_spec):
            return component_spec._unbatch()

        unbatched_structure = nest.map_structure(unbatch_component,
                                                 element_structure)
        self.assertEqual(unbatched_structure, expected_unbatched_structure)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
         lambda: constant_op.constant([1.0, 2.0])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[0]], values=[13], dense_shape=[2])),
        ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
         lambda: ragged_factory_ops.constant([[1]])),
        ("Nest", lambda:
         (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
         lambda: (constant_op.constant([1.0, 2.0]),
                  sparse_tensor.SparseTensor(
                      indices=[[0]], values=[13], dense_shape=[2]))),
    )
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks `to_batched_tensor_list` and the element rebuilt at index 0.

        Args:
          value_fn: Zero-argument callable that builds the batched value.
          element_0_fn: Zero-argument callable that builds the value expected
            after taking index 0 of every batched tensor.
        """
        batched_value = value_fn()
        batched_spec = structure.type_spec_from_value(batched_value)
        tensors = structure.to_batched_tensor_list(batched_spec,
                                                   batched_value)

        # Every test case uses a batch dimension of 2.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors.
        non_variant = [t for t in tensors if t.dtype != dtypes.variant]
        for tensor in non_variant:
            leading_dim = self.evaluate(array_ops.shape(tensor)[0])
            self.assertEqual(2, leading_dim)

        # Rebuild element 0 from the first slice of each batched tensor and
        # compare it, component by component, with the expected value.
        expected_element_0 = self.evaluate(element_0_fn())

        def unbatch_component(component_spec):
            return component_spec._unbatch()

        unbatched_spec = nest.map_structure(unbatch_component, batched_spec)
        first_slices = [t[0] for t in tensors]
        actual_element_0 = structure.from_tensor_list(unbatched_spec,
                                                      first_slices)

        expected_flat = nest.flatten(expected_element_0)
        actual_flat = nest.flatten(actual_element_0)
        for expected, actual in zip(expected_flat, actual_flat):
            self.assertValuesEqual(expected, actual)

    # pylint: enable=g-long-lambda

    def testDatasetSpecConstructor(self):
        """Checks the attributes stored by the `DatasetSpec` constructor."""
        element_spec = {
            "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
            "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
            "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
        }
        dataset_spec = dataset_ops.DatasetSpec(element_spec, [5])
        self.assertEqual(dataset_spec._element_spec, element_spec)

        # The dataset shape was passed as a plain list; the stored value must
        # equal the corresponding `TensorShape`.
        stored_shape = dataset_spec._dataset_shape
        self.assertEqual(stored_shape, tensor_shape.TensorShape([5]))

    def testCustomMapping(self):
        """Checks the type spec derived from a `CustomMap` value."""
        custom_value = CustomMap(foo=constant_op.constant(37.))
        derived_spec = structure.type_spec_from_value(custom_value)

        # The spec must be a `CustomMap` too, holding a scalar float32 spec
        # under the same key.
        self.assertIsInstance(derived_spec, CustomMap)
        foo_spec = derived_spec["foo"]
        self.assertEqual(foo_spec, tensor_spec.TensorSpec([], dtypes.float32))

    def testObjectProxy(self):
        """Checks that a `wrapt.ObjectProxy` yields the wrapped value's spec.

        The spec derived from a proxied namedtuple must equal the spec
        derived from an equal, unproxied namedtuple.
        """
        nt_type = collections.namedtuple("A", ["x", "y"])
        proxied = wrapt.ObjectProxy(nt_type(1, 2))
        proxied_spec = structure.type_spec_from_value(proxied)
        unproxied_spec = structure.type_spec_from_value(nt_type(1, 2))
        self.assertEqual(unproxied_spec, proxied_spec)