def testConcatTuple(self):
    """array_ops.concat accepts a tuple of StructuredTensors along axis 0."""
    first = StructuredTensor.from_pyval([{"a": 3}])
    second = StructuredTensor.from_pyval([{"a": 4}])
    # Passed as a tuple rather than a list; that is what this test checks.
    joined = array_ops.concat((first, second), axis=0)
    self.assertAllEqual(joined, [{"a": 3}, {"a": 4}])
 def testRandomShuffle2022Eager(self):
     """random_shuffle with global seed 1066 and op seed 2022 is reproducible.

     The expected result is the input rows in the fixed order 1, 0, 3, 4, 2.
     """
     rows = [
         {"x0": 0, "y": {"z": [[3, 13]]}},
         {"x0": 1, "y": {"z": [[3], [4, 13]]}},
         {"x0": 2, "y": {"z": [[3, 5], [4]]}},
         {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
         {"x0": 4, "y": {"z": [[3], [4]]}},
     ]  # pyformat: disable
     original = StructuredTensor.from_pyval(rows)
     # Row permutation observed for this (global seed, op seed) pair.
     shuffled_order = (1, 0, 3, 4, 2)
     expected = StructuredTensor.from_pyval([rows[i] for i in shuffled_order])
     random_seed.set_seed(1066)
     result = structured_array_ops.random_shuffle(original, seed=2022)
     self.assertAllEqual(result, expected)
Example #3
0
    def testPartitionOuterDims(self):
        """partition_outer_dimension with ragged and uniform row partitions."""
        # NOTE(review): this early return (originally tagged "# TESTING")
        # skips the test outside eager mode; confirm it is still intended.
        if not context.executing_eagerly():
            return
        a = {"x": 1, "y": [1, 2]}
        b = {"x": 2, "y": [3, 4]}
        c = {"x": 3, "y": [5, 6]}
        d = {"x": 4, "y": [7, 8]}
        st1 = StructuredTensor.from_pyval([a, b, c, d])

        # Ragged partitions: first by row_splits, then by row_lengths.
        splits = row_partition.RowPartition.from_row_splits([0, 2, 2, 3, 4])
        st2 = st1.partition_outer_dimension(splits)
        self.assertAllEqual(st2, [[a, b], [], [c], [d]])

        lengths = row_partition.RowPartition.from_row_lengths([1, 0, 3, 0])
        st3 = st2.partition_outer_dimension(lengths)
        self.assertAllEqual(st3, [[[a, b]], [], [[], [c], [d]], []])

        # With a uniform row length, `x` is partitioned into a dense Tensor
        # (not a RaggedTensor), so the expected value uses an explicit spec.
        uniform = row_partition.RowPartition.from_uniform_row_length(
            uniform_row_length=2, nvals=4, nrows=2)
        st4 = st1.partition_outer_dimension(uniform)
        expected_spec = structured_tensor.StructuredTensorSpec(
            [2, 2], {
                "x": tensor_spec.TensorSpec([2, 2], dtypes.int32),
                "y": ragged_tensor.RaggedTensorSpec([2, 2, None],
                                                    dtypes.int32),
            })
        expected = structured_tensor.StructuredTensor.from_pyval(
            [[a, b], [c, d]], expected_spec)
        self.assertAllEqual(st4, expected)
 def testExtendOpErrorNotList(self):
   """_extend_op raises ValueError when given a lone StructuredTensor."""
   not_a_list = StructuredTensor.from_pyval({})

   def leaf_op(values):
     return values[0]

   with self.assertRaisesRegex(ValueError, "Expected a list"):
     structured_array_ops._extend_op(not_a_list, leaf_op)
 def testTupleFieldValue(self):
     """field_value accepts tuple paths; descending into a scalar raises KeyError."""
     st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
     self.assertAllEqual(st.field_value(("a",)), 5)
     self.assertAllEqual(st.field_value(("b", "c")), [1, 2, 3])
     # Raw string: "\(" is an invalid escape in a plain literal (a
     # SyntaxWarning since Python 3.12). The regex value itself is unchanged.
     expected = r"Field path \(.*a.*,.*b.*\) not found in .*"
     # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(KeyError, expected):
         st.field_value(("a", "b"))
Example #6
0
 def testTupleFieldValue(self):
     """field_value accepts tuple paths; descending into a scalar raises KeyError."""
     st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
     self.assertAllEqual(st.field_value(("a",)), 5)
     self.assertAllEqual(st.field_value(("b", "c")), [1, 2, 3])
     # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(
             KeyError, r"Field path \('a', 'b'\) not found in .*"):
         st.field_value(("a", "b"))
 def testExpandDimsScalar(self):
   """expand_dims on the final axis of a struct whose fields are scalars."""
   # Expanding the final dimension when there are scalar fields yields shape
   # (2, None, None, 1); building the same value with from_pyval would give
   # (2, None, None, None) instead.
   pyval = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]
   result = array_ops.expand_dims(StructuredTensor.from_pyval(pyval), 3)
   want = tensor_shape.TensorShape([2, None, None, 1])
   self.assertEqual(repr(want), repr(result.shape))
  def testSizeAlt(self, values, dtype, expected):
    """Both array_ops.size and size_v2 return `expected` for the struct."""
    st = StructuredTensor.from_pyval(values)
    # NOTE: size is very robust; no argument combination here is expected
    # to make it fail.
    for size_fn in (array_ops.size, array_ops.size_v2):
      self.assertAllEqual(size_fn(st, out_type=dtype), expected)
  def testOnesLikeObjectAlt(self, values, dtype, expected):
    """Both ones_like and ones_like_v2 return `expected` for the struct."""
    st = StructuredTensor.from_pyval(values)
    # NOTE: ones_like is very robust; no argument combination here is
    # expected to make it fail.
    for ones_like_fn in (array_ops.ones_like, array_ops.ones_like_v2):
      self.assertAllEqual(ones_like_fn(st, dtype), expected)
    def testRandomShuffleScalarError(self):
        """random_shuffle on a scalar (dict-valued) struct raises ValueError."""
        scalar_st = StructuredTensor.from_pyval(
            {"x0": 2, "y": {"z": [[3, 5], [4]]}})  # pyformat: disable
        with self.assertRaisesRegex(ValueError, "scalar"):
            random_ops.random_shuffle(scalar_st)
 def testGather(self, params, indices, axis, batch_dims, expected):
     """array_ops.gather on StructuredTensor params returns `expected`."""
     st_params = StructuredTensor.from_pyval(params)
     # validate_indices isn't actually used, and we aren't testing names;
     # both are still passed explicitly.
     gathered = array_ops.gather(
         st_params, indices, validate_indices=True, axis=axis, name=None,
         batch_dims=batch_dims)
     self.assertAllEqual(gathered, expected)
 def testGatherError(self, params, indices, axis, batch_dims, error_type,
                     error_regex):
     """structured_array_ops.gather raises `error_type` matching `error_regex`."""
     st_params = StructuredTensor.from_pyval(params)
     with self.assertRaisesRegex(error_type, error_regex):
         structured_array_ops.gather(
             st_params, indices, validate_indices=True, axis=axis, name=None,
             batch_dims=batch_dims)
 def testGatherRagged(self, params, indices, axis, batch_dims, expected):
     """array_ops.gather with RaggedTensor indices on StructuredTensor params."""
     st_params = StructuredTensor.from_pyval(params)
     # Shouldn't need to do this, but see cl/366396997
     ragged_indices = ragged_factory_ops.constant(indices)
     # validate_indices isn't actually used, and we aren't testing names.
     gathered = array_ops.gather(
         st_params, ragged_indices, validate_indices=True, axis=axis,
         name=None, batch_dims=batch_dims)
     self.assertAllEqual(gathered, expected)
Example #14
0
 def _lambda_for_fields(self):
     """Returns a zero-arg callable that builds a dict of field values.

     Every field's leading dimensions are [1, 2, 3, 1]. Construction is
     deferred inside the returned callable, so no tensors are created until
     it is invoked.
     """

     def make_fields():
         # Python-side payloads for the two nested StructuredTensor fields.
         g_rows = [[[{'x': j, 'y': k}] for k in range(3)] for j in range(2)]
         h_rows = [[[[{'x': j, 'y': k, 'z': z} for z in range(j)]]
                    for k in range(3)]
                   for j in range(2)]  # pylint: disable=g-complex-comprehension
         return {
             'a': np.ones([1, 2, 3, 1]),
             'b': np.ones([1, 2, 3, 1, 5]),
             'c': ragged_factory_ops.constant(
                 np.zeros([1, 2, 3, 1], dtype=np.uint8), dtype=dtypes.uint8),
             'd': ragged_factory_ops.constant(
                 np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1),
             'e': ragged_factory_ops.constant(
                 np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
             'f': ragged_factory_ops.constant(
                 np.zeros([1, 2, 3, 1, 3]), dtype=dtypes.float32),
             'g': StructuredTensor.from_pyval([g_rows]),
             'h': StructuredTensor.from_pyval([h_rows]),
         }

     return make_fields
Example #15
0
 def testRepr(self):
   """repr shows nested fields; tensor reprs differ between eager and graph."""
   st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
   if context.executing_eagerly():
     a_repr = "tf.Tensor(5, shape=(), dtype=int32)"
     c_repr = "tf.Tensor([1 2 3], shape=(3,), dtype=int32)"
   else:
     a_repr = 'Tensor("Const:0", shape=(), dtype=int32)'
     c_repr = 'Tensor("RaggedConstant/Const:0", shape=(3,), dtype=int32)'
   template = ('<StructuredTensor(fields={"a": %s, "b": '
               '<StructuredTensor(fields={"c": %s}, shape=())>}, shape=())>')
   self.assertEqual(repr(st), template % (a_repr, c_repr))
 def testConcatNotAList(self):
   """structured_array_ops.concat rejects a bare StructuredTensor."""
   lone_st = StructuredTensor.from_pyval({})
   expected_msg = "values must be a list of StructuredTensors"
   with self.assertRaisesRegex(ValueError, expected_msg):
     structured_array_ops.concat(lone_st, 0)
 def testExpandDims(self, st, axis, expected):
   """expand_dims(from_pyval(st), axis) equals `expected`."""
   expanded = array_ops.expand_dims(StructuredTensor.from_pyval(st), axis)
   self.assertAllEqual(expanded, expected)
 def testConcatError(self, values, axis, error_type, error_regex):
   """array_ops.concat raises `error_type` matching `error_regex`."""
   st_values = list(map(StructuredTensor.from_pyval, values))
   with self.assertRaisesRegex(error_type, error_regex):
     array_ops.concat(st_values, axis)
 def testConcatWithRagged(self):
   """concat rejects a list mixing a StructuredTensor with another tensor."""
   # Despite the test name, the non-StructuredTensor element is a dense
   # constant Tensor.
   mixed = [StructuredTensor.from_pyval({}), array_ops.constant(3)]
   expected_msg = "values must be a list of StructuredTensors"
   with self.assertRaisesRegex(ValueError, expected_msg):
     array_ops.concat(mixed, 0)
 def testRankAlt(self, values, expected):
   """array_ops.rank of the StructuredTensor built from `values`."""
   # NOTE: rank is very robust; no argument here is expected to make it fail.
   rank = array_ops.rank(StructuredTensor.from_pyval(values))
   self.assertAllEqual(expected, rank)
 def testConcat(self, values, axis, expected):
   """array_ops.concat of StructuredTensors built from `values`."""
   st_values = list(map(StructuredTensor.from_pyval, values))
   self.assertAllEqual(array_ops.concat(st_values, axis), expected)
Example #22
0
 def testMergeDims(self, st, outer_axis, inner_axis, expected):
     """merge_dims(outer_axis, inner_axis) on from_pyval(st) gives `expected`."""
     merged = StructuredTensor.from_pyval(st).merge_dims(outer_axis, inner_axis)
     self.assertAllEqual(merged, expected)
Example #23
0
 def testMergeDimsError(self):
     """merge_dims raises ValueError when outer_axis exceeds inner_axis."""
     st = StructuredTensor.from_pyval([[[{"a": 5}]]])
     # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(
             ValueError,
             r"Expected outer_axis \(2\) to be less than inner_axis \(1\)"):
         st.merge_dims(2, 1)
Example #24
0
 def testRepr(self):
     """repr of a scalar StructuredTensor lists its field names and shape."""
     st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
     expected = "<StructuredTensor(fields={'a', 'b'}, shape=())>"
     self.assertEqual(repr(st), expected)
 def testExpandDimsAxisTooSmall(self):
   """expand_dims with an out-of-range negative axis raises ValueError."""
   nested = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
   st = StructuredTensor.from_pyval(nested)
   expected_msg = "axis=-5 out of bounds: expected -4<=axis<4"
   with self.assertRaisesRegex(ValueError, expected_msg):
     array_ops.expand_dims(st, -5)
Example #26
0
class StructuredTensorTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
    def assertAllEqual(self, a, b, msg=None):
        """assertAllEqual extended to StructuredTensors.

        If neither argument is a StructuredTensor, defers to the base class.
        If exactly one is, the other is converted with from_pyval and the two
        are compared field by field without checking shapes. If both are
        StructuredTensors, shapes are compared as well.
        """
        a_is_st = isinstance(a, structured_tensor.StructuredTensor)
        b_is_st = isinstance(b, structured_tensor.StructuredTensor)
        if not (a_is_st or b_is_st):
            return super().assertAllEqual(a, b, msg)
        if a_is_st and b_is_st:
            self._assertStructuredEqual(a, b, msg, True)
            return
        # Exactly one side is a python value. Convert it and skip the shape
        # check (presumably because from_pyval may infer a less specific
        # static shape than the tensor under test).
        if a_is_st:
            b = structured_tensor.StructuredTensor.from_pyval(b)
        else:
            a = structured_tensor.StructuredTensor.from_pyval(a)
        self._assertStructuredEqual(a, b, msg, False)

    def _assertStructuredEqual(self, a, b, msg, check_shape):
        """Recursively compares two StructuredTensors field by field.

        Args:
          a: First StructuredTensor.
          b: Second StructuredTensor.
          msg: Message forwarded to the leaf assertAllEqual calls.
          check_shape: If True, repr(a.shape) must equal repr(b.shape) at
            every level of nesting.
        """
        if check_shape:
            self.assertEqual(repr(a.shape), repr(b.shape))
        self.assertEqual(set(a.field_names()), set(b.field_names()))
        for name in a.field_names():
            lhs, rhs = a.field_value(name), b.field_value(name)
            # The exact Python types must match (e.g. Tensor vs RaggedTensor).
            self.assertIs(type(lhs), type(rhs))
            if not isinstance(lhs, structured_tensor.StructuredTensor):
                self.assertAllEqual(lhs, rhs, msg)
                continue
            self._assertStructuredEqual(lhs, rhs, msg, check_shape)

    def testConstructorIsPrivate(self):
        """Calling the StructuredTensor constructor directly raises ValueError."""
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    "StructuredTensor constructor is private"):
            structured_tensor.StructuredTensor({}, (), None, ())

    @parameterized.named_parameters([
        # Scalar (rank=0) StructuredTensors.
        {
            "testcase_name": "Rank0_WithNoFields",
            "shape": [],
            "fields": {},
        },
        {
            "testcase_name": "Rank0_WithTensorFields",
            "shape": [],
            "fields": {
                "Foo": 5,
                "Bar": [1, 2, 3]
            },
        },
        {
            "testcase_name": "Rank0_WithRaggedFields",
            "shape": [],
            "fields": {
                # note: fields have varying rank & ragged_rank.
                "p":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "q":
                ragged_factory_ops.constant_value([[[4]], [], [[5, 6]]]),
                "r":
                ragged_factory_ops.constant_value([[[4]], [], [[5]]],
                                                  ragged_rank=1),
                "s":
                ragged_factory_ops.constant_value([[[4]], [], [[5]]],
                                                  ragged_rank=2),
            },
        },
        {
            "testcase_name": "Rank0_WithStructuredFields",
            "shape": [],
            "fields": lambda: {
                "foo":
                StructuredTensor.from_pyval({
                    "a": 1,
                    "b": [1, 2, 3]
                }),
                "bar":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
        },
        {
            "testcase_name": "Rank0_WithMixedFields",
            "shape": [],
            "fields": lambda: {
                "f1": 5,
                "f2": [1, 2, 3],
                "f3": ragged_factory_ops.constant_value([[1, 2], [3]]),
                "f4": StructuredTensor.from_pyval({
                    "a": 1,
                    "b": [1, 2, 3]
                }),
            },
        },
        # Vector (rank=1) StructuredTensors.
        {
            "testcase_name": "Rank1_WithNoFields",
            "shape": [2],
            "fields": {},
        },
        {
            "testcase_name": "Rank1_WithExplicitNrows",
            "shape": [None],
            "nrows": 2,
            "fields": {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]]
            },
            "expected_shape": [2],
        },
        {
            "testcase_name": "Rank1_WithTensorFields",
            "shape": [2],
            "fields": {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]]
            },
        },
        {
            "testcase_name": "Rank1_WithRaggedFields",
            "shape": [2],
            "fields": {
                # note: fields have varying rank & ragged_rank.
                "p":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "q":
                ragged_factory_ops.constant_value([[[4]], [[5, 6], [7]]]),
                "r":
                ragged_factory_ops.constant_value([[], [[[12]], [[13]]]]),
                "s":
                ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
                                                  ragged_rank=1),
                "t":
                ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
                                                  ragged_rank=2),
            },
        },
        {
            "testcase_name": "Rank1_WithStructuredFields",
            "shape": [2],
            "fields": lambda: {
                "foo":
                StructuredTensor.from_pyval([{
                    "a": 1,
                    "b": [1, 2, 3]
                }, {
                    "a": 2,
                    "b": []
                }]),
                "bar":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
        },
        {
            "testcase_name": "Rank1_WithMixedFields",
            "shape": [2],
            "fields": lambda: {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]],
                "r":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "s":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
        },
        {
            "testcase_name": "Rank1_WithNoElements",
            "shape": [0],
            "fields": lambda: {
                "x": [],
                "y": np.zeros([0, 8]),
                "r": ragged_factory_ops.constant([], ragged_rank=1),
                "s": StructuredTensor.from_pyval([]),
            },
        },
        {
            "testcase_name": "Rank1_InferDimSize",
            "shape": [None],
            "fields": lambda: {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]],
                "r":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "p":
                ragged_factory_ops.constant_value([[4], [5, 6, 7]]),
                "foo":
                StructuredTensor.from_pyval([{
                    "a": 1,
                    "b": [1, 2, 3]
                }, {
                    "a": 2,
                    "b": []
                }]),
                "bar":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
            "expected_shape": [2],  # inferred from field values.
        },
        # Matrix (rank=2) StructuredTensors.
        {
            "testcase_name": "Rank2_WithNoFields",
            "shape": [2, 8],
            "fields": {},
        },
        {
            "testcase_name":
            "Rank2_WithNoFieldsAndExplicitRowPartitions",
            "shape": [2, None],
            "row_partitions":
            lambda: [row_partition.RowPartition.from_row_lengths([3, 7])],
            "fields": {},
        },
        {
            "testcase_name": "Rank2_WithTensorFields",
            "shape": [None, None],
            "fields": {
                "x": [[1, 2, 3], [4, 5, 6]],
                "y": np.ones([2, 3, 8])
            },
            "expected_shape": [2, 3],  # inferred from field values.
        },
        {
            "testcase_name": "Rank2_WithRaggedFields",
            "shape": [2, None],  # ragged shape = [[*, *], [*]]
            "fields": {
                # Note: fields must have identical row_splits.
                "a":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "b":
                ragged_factory_ops.constant_value([[4, 5], [6]]),
                "c":
                ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
                "d":
                ragged_factory_ops.constant_value([[[[1, 2], [3]],
                                                    [[4], [], [5]]],
                                                   [[[6, 7, 8], []]]]),
            },
        },
        {
            "testcase_name": "Rank2_WithStructuredFields",
            "shape": [2, None],  # ragged shape = [[*], [*, *]]
            "fields": lambda: {
                # Note: fields must have identical row_splits.
                "a":
                StructuredTensor.from_pyval([[{
                    "x": 1
                }], [{
                    "x": 2
                }, {
                    "x": 3
                }]]),
                "b":
                StructuredTensor.from_pyval([[[{
                    "y": 1
                }]], [[], [{
                    "y": 2
                }, {
                    "y": 3
                }]]]),
            },
        },
        {
            "testcase_name": "Rank2_WithMixedFields",
            "shape": [2, None],
            "fields": lambda: {
                "a": [[1, 2], [3, 4]],
                "b":
                ragged_factory_ops.constant_value([[1, 2], [3, 4]]),
                "c":
                StructuredTensor.from_pyval([[[{
                    "y": 1
                }], []], [[], [{
                    "y": 2
                }, {
                    "y": 3
                }]]]),
                "d":
                ragged_factory_ops.constant_value([[[1, 2], []], [[3], [4]]]),
            },
            "expected_shape": [2, 2],
        },
        # Rank=4 StructuredTensors.
        {
            "testcase_name":
            "Rank4_WithNoFields",
            "shape": [1, None, None, 3],
            "fields": {},
            "row_partitions":
            lambda: [
                row_partition.RowPartition.from_row_lengths([3]),
                row_partition.RowPartition.from_row_lengths([2, 0, 1]),
                row_partition.RowPartition.from_uniform_row_length(3, nvals=9)
            ]
        },
        {
            "testcase_name": "Rank4_WithMixedFields",
            "shape": [1, None, None, 1],
            "fields": lambda: {
                "a":
                np.ones([1, 2, 3, 1]),
                "b":
                np.ones([1, 2, 3, 1, 5]),
                "c":
                ragged_factory_ops.constant(np.zeros([1, 2, 3, 1])),
                "d":
                ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3]).tolist(),
                                            ragged_rank=1),
                "e":
                ragged_factory_ops.constant(
                    np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
                "f":
                ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3])),
                "g":
                StructuredTensor.from_pyval([[[[{
                    "x": j,
                    "y": k
                }] for k in range(3)] for j in range(2)]]),
                "h":
                StructuredTensor.from_pyval([[[[[{
                    "x": j,
                    "y": k,
                    "z": z
                } for z in range(j)]] for k in range(3)] for j in range(2)]]),
            },
            "expected_shape": [1, 2, 3, 1],  # inferred from field values.
        },
    ])  # pyformat: disable
    def testFromFields(self,
                       shape,
                       fields,
                       expected_shape=None,
                       nrows=None,
                       row_partitions=None):
        """from_fields builds the expected shape and fields, with and without validation.

        `fields`, `nrows` and `row_partitions` may be zero-arg callables so that
        tensors are only constructed when the test runs.
        """

        def resolve(arg):
            # Deferred construction: invoke callables, pass other values through.
            return arg() if callable(arg) else arg

        fields = resolve(fields)
        nrows = resolve(nrows)
        row_partitions = resolve(row_partitions)
        if expected_shape is None:
            expected_shape = shape

        for validate in (True, False):
            struct = StructuredTensor.from_fields(
                fields,
                shape,
                nrows=nrows,
                row_partitions=row_partitions,
                validate=validate)
            self.assertEqual(struct.shape.as_list(), expected_shape)
            self.assertLen(expected_shape, struct.rank)
            self.assertCountEqual(struct.field_names(), tuple(fields.keys()))
            for name, value in fields.items():
                actual = struct.field_value(name)
                self.assertIsInstance(
                    actual, (ops.Tensor, structured_tensor.StructuredTensor,
                             ragged_tensor.RaggedTensor))
                self.assertAllEqual(actual, value)

    @parameterized.parameters([
        dict(fields={}, shape=object(), err=TypeError),
        dict(fields=object(),
             shape=[],
             err=TypeError,
             msg="fields must be a dictionary"),
        dict(fields={1: 2},
             shape=[],
             err=TypeError,
             msg="Unexpected type for key"),
        dict(fields={"x": object()},
             shape=[],
             err=TypeError,
             msg="Unexpected type for value"),
        dict(fields={},
             shape=None,
             err=ValueError,
             msg="StructuredTensor's shape must have known rank"),
        dict(
            fields={"f": 5},
            shape=[5],
            err=ValueError,
            msg=r"Field f has shape \(\), which is incompatible with the shape "
            r"that was specified or inferred from other fields: \(5,\)"),
        dict(fields=dict(x=[1], y=[]),
             shape=[None],
             err=ValueError,
             msg=r"Field . has shape .*, which is incompatible with the shape "
             r"that was specified or inferred from other fields: .*"),
        dict(fields={"": 5},
             shape=[],
             err=ValueError,
             msg="Field name '' is not currently allowed."),
        dict(fields={"_": 5},
             shape=[],
             err=ValueError,
             msg="Field name '_' is not currently allowed."),
        dict(
            fields={
                "r1": ragged_factory_ops.constant_value([[1, 2], [3]]),
                "r2": ragged_factory_ops.constant_value([[1, 2, 3], [4]])
            },
            shape=[2, None],
            validate=True,
            err=errors.InvalidArgumentError,
            msg=r"incompatible row_splits",
        ),
        dict(fields={},
             shape=(),
             nrows=5,
             err=ValueError,
             msg="nrows must be None if shape.rank==0"),
        dict(fields={},
             shape=(),
             row_partitions=[0],
             err=ValueError,
             msg=r"row_partitions must be None or \[\] if shape.rank<2"),
        dict(fields={},
             shape=(None, None, None),
             row_partitions=[],
             err=ValueError,
             msg=r"len\(row_partitions\) must be shape.rank-1"),
        dict(fields={},
             shape=[None],
             err=ValueError,
             msg="nrows must be specified if rank==1 and `fields` is empty."),
        dict(fields={},
             shape=[None, None],
             err=ValueError,
             msg="row_partitions must be specified if rank>1 and `fields` "
             "is empty."),
        dict(fields={},
             shape=[None, None],
             nrows=lambda: constant_op.constant(2, dtypes.int32),
             row_partitions=lambda:
             [row_partition.RowPartition.from_row_lengths([3, 4])],
             err=ValueError,
             msg="field values have incompatible row_partition dtypes"),
        dict(fields=lambda: {
            "a":
            ragged_factory_ops.constant([[1]], row_splits_dtype=dtypes.int32),
            "b":
            ragged_factory_ops.constant([[1]], row_splits_dtype=dtypes.int64)
        },
             shape=[None, None],
             err=ValueError,
             msg="field values have incompatible row_partition dtypes"),
        dict(fields=lambda: {
            "a": array_ops.placeholder_with_default(np.array([1, 2, 3]), None),
            "b": array_ops.placeholder_with_default(np.array([4, 5]), None)
        },
             validate=True,
             shape=[None],
             err=(ValueError, errors.InvalidArgumentError),
             msg="fields have incompatible nrows",
             test_in_eager=False),
    ])
    def testFromFieldsErrors(self,
                             fields,
                             shape,
                             nrows=None,
                             row_partitions=None,
                             validate=False,
                             err=ValueError,
                             msg=None,
                             test_in_eager=True):
        """from_fields, or evaluating its result, raises `err` matching `msg`.

        Args:
          fields: Field dict, or a zero-arg callable returning one.
          shape: Shape passed to from_fields.
          nrows: nrows value, or a zero-arg callable returning one.
          row_partitions: row_partitions list, or a zero-arg callable.
          validate: Forwarded to from_fields.
          err: Expected exception type (or tuple of types).
          msg: Regex the error message must match. With None, only the
            exception type is checked.
          test_in_eager: If False, the case is skipped when executing eagerly.
        """
        if not test_in_eager and context.executing_eagerly():
            return
        if callable(fields):
            fields = fields()  # deferred construction.
        if callable(nrows):
            nrows = nrows()  # deferred construction.
        if callable(row_partitions):
            row_partitions = row_partitions()  # deferred construction.
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed in Python 3.12.
        with self.assertRaisesRegex(err, msg):
            struct = StructuredTensor.from_fields(
                fields=fields,
                shape=shape,
                nrows=nrows,
                row_partitions=row_partitions,
                validate=validate)
            # Evaluating inside the context lets runtime validation failures
            # (e.g. InvalidArgumentError) count as the expected error too.
            for field_name in struct.field_names():
                self.evaluate(struct.field_value(field_name))
            self.evaluate(struct.nrows())

    def testMergeNrowsErrors(self):
        """_merge_nrows rejects a value whose row count conflicts with static nrows."""
        nrows = constant_op.constant(5)
        static_nrows = tensor_shape.Dimension(5)
        value = constant_op.constant([1, 2, 3])
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    "fields have incompatible nrows"):
            structured_tensor._merge_nrows(nrows,
                                           static_nrows,
                                           value,
                                           dtypes.int32,
                                           validate=False)

    def testNestedStructConstruction(self):
        """StructuredTensors can hold ragged and nested-struct fields."""
        rt = ragged_factory_ops.constant([[1, 2], [3]])
        struct1 = StructuredTensor.from_fields(shape=[], fields={"x": [1, 2]})
        struct2 = StructuredTensor.from_fields(shape=[2], fields={"x": [1, 2]})
        struct3 = StructuredTensor.from_fields(
            shape=[], fields={"r": rt, "s": struct1})
        struct4 = StructuredTensor.from_fields(
            shape=[2], fields={"r": rt, "s": struct2})

        # (outer struct, nested struct stored under "s", expected outer shape)
        cases = ((struct3, struct1, []), (struct4, struct2, [2]))
        for outer, inner, outer_shape in cases:
            self.assertEqual(outer.shape.as_list(), outer_shape)
            self.assertEqual(outer.rank, len(outer_shape))
            self.assertEqual(set(outer.field_names()), {"r", "s"})
            self.assertAllEqual(outer.field_value("r"), rt)
            self.assertAllEqual(outer.field_value("s"), inner)

    def testPartitionOuterDims(self):
        """partition_outer_dimension with ragged and uniform row partitions."""
        # NOTE(review): this early return (originally tagged "# TESTING")
        # skips the test outside eager mode; confirm it is still intended.
        if not context.executing_eagerly():
            return
        a = {"x": 1, "y": [1, 2]}
        b = {"x": 2, "y": [3, 4]}
        c = {"x": 3, "y": [5, 6]}
        d = {"x": 4, "y": [7, 8]}
        st1 = StructuredTensor.from_pyval([a, b, c, d])

        # Ragged partitions: first by row_splits, then by row_lengths.
        splits = row_partition.RowPartition.from_row_splits([0, 2, 2, 3, 4])
        st2 = st1.partition_outer_dimension(splits)
        self.assertAllEqual(st2, [[a, b], [], [c], [d]])

        lengths = row_partition.RowPartition.from_row_lengths([1, 0, 3, 0])
        st3 = st2.partition_outer_dimension(lengths)
        self.assertAllEqual(st3, [[[a, b]], [], [[], [c], [d]], []])

        # With a uniform row length, `x` is partitioned into a dense Tensor
        # (not a RaggedTensor), so the expected value uses an explicit spec.
        uniform = row_partition.RowPartition.from_uniform_row_length(
            uniform_row_length=2, nvals=4, nrows=2)
        st4 = st1.partition_outer_dimension(uniform)
        expected_spec = structured_tensor.StructuredTensorSpec(
            [2, 2], {
                "x": tensor_spec.TensorSpec([2, 2], dtypes.int32),
                "y": ragged_tensor.RaggedTensorSpec([2, 2, None],
                                                    dtypes.int32),
            })
        expected = structured_tensor.StructuredTensor.from_pyval(
            [[a, b], [c, d]], expected_spec)
        self.assertAllEqual(st4, expected)

    def testPartitionOuterDimension3(self):
        """Partitioning a rank-1 struct twice produces a rank-3 struct."""
        values = array_ops.constant([[1, 2], [3, 4], [5, 6]])
        rt = ragged_tensor.RaggedTensor.from_value_rowids(values, [0, 0, 1])
        result = structured_tensor.StructuredTensor.from_fields({"r": rt}, [2])
        for _ in range(2):
            result = result.partition_outer_dimension(
                row_partition.RowPartition.from_row_splits([0, 1, 2]))
        self.assertEqual(3, result.rank)

    def testPartitionOuterDimsErrors(self):
        """partition_outer_dimension validates the struct rank and arg type."""
        st = StructuredTensor.from_fields({})
        partition = row_partition.RowPartition.from_row_splits([0])
        # `st` has shape () and so has no outer dimension to partition.
        with self.assertRaisesRegex(ValueError,
                                    r"Shape \(\) must have rank at least 1"):
            st.partition_outer_dimension(partition)

        # Anything other than a RowPartition is rejected with TypeError.
        with self.assertRaisesRegex(TypeError,
                                    "row_partition must be a RowPartition"):
            st.partition_outer_dimension(10)

    @parameterized.named_parameters([
        {
            "testcase_name": "ScalarEmpty",
            "pyval": {},
            "expected":
                lambda: StructuredTensor.from_fields(shape=[], fields={})
        },
        {
            "testcase_name": "ScalarSimple",
            "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]},
            "expected": lambda: StructuredTensor.from_fields(
                shape=[],
                fields={
                    "a": 12,
                    "b": [1, 2, 3],
                    "c": ragged_factory_ops.constant([[1, 2], [3]])
                })
        },
        {
            "testcase_name": "ScalarSimpleWithTypeSpec",
            "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]},
            "type_spec": structured_tensor.StructuredTensorSpec(
                [], {
                    "a": tensor_spec.TensorSpec([], dtypes.int32),
                    "b": tensor_spec.TensorSpec([None], dtypes.int32),
                    "c": ragged_tensor.RaggedTensorSpec([None, None],
                                                        dtypes.int32)
                }),
            "expected": lambda: StructuredTensor.from_fields(
                shape=[],
                fields={
                    "a": 12,
                    "b": [1, 2, 3],
                    "c": ragged_factory_ops.constant([[1, 2], [3]])
                })
        },
        {
            "testcase_name": "ScalarWithNestedStruct",
            "pyval": {
                "a": 12,
                "b": [1, 2, 3],
                "c": {"x": b"Z", "y": [10, 20]}
            },
            "expected": lambda: StructuredTensor.from_fields(
                shape=[],
                fields={
                    "a": 12,
                    "b": [1, 2, 3],
                    "c": StructuredTensor.from_fields(
                        shape=[], fields={"x": "Z", "y": [10, 20]})
                })
        },
        {
            "testcase_name": "EmptyList",
            "pyval": [],
            "expected": lambda: [],
        },
        {
            "testcase_name": "ListOfEmptyList",
            "pyval": [[], []],
            "expected": lambda: [[], []],
        },
        {
            "testcase_name": "EmptyListWithTypeSpecAndFields",
            "pyval": [],
            "type_spec": structured_tensor.StructuredTensorSpec(
                [0], {"a": tensor_spec.TensorSpec(None, dtypes.int32)}),
            "expected": lambda: StructuredTensor.from_fields(
                shape=[0], fields={"a": []})
        },
        {
            "testcase_name": "EmptyListWithTypeSpecNoFieldsShape0_5",
            "pyval": [],
            "type_spec": structured_tensor.StructuredTensorSpec([0, 5], {}),
            "expected": lambda: StructuredTensor.from_fields(
                shape=[0, 5], fields={})
        },
        {
            "testcase_name": "EmptyListWithTypeSpecNoFieldsShape1_0",
            "pyval": [[]],
            "type_spec": structured_tensor.StructuredTensorSpec([1, 0], {}),
            "expected": lambda: StructuredTensor.from_fields(
                shape=[1, 0], fields={})
        },
        {
            "testcase_name": "VectorOfDict",
            "pyval": [{"a": 1}, {"a": 2}],
            "expected": lambda: StructuredTensor.from_fields(
                shape=[2], fields={"a": [1, 2]})
        },
        {
            "testcase_name": "VectorOfDictWithNestedStructScalar",
            "pyval": [{"a": 1, "b": {"x": [1, 2]}},
                      {"a": 2, "b": {"x": [3]}}],
            "expected": lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    "a": [1, 2],
                    "b": StructuredTensor.from_fields(
                        shape=[2],
                        fields={
                            "x": ragged_factory_ops.constant([[1, 2], [3]])
                        })
                }),
        },
        {
            "testcase_name": "VectorOfDictWithNestedStructVector",
            "pyval": [{"a": 1, "b": [{"x": [1, 2]}, {"x": [5]}]},
                      {"a": 2, "b": [{"x": [3]}]}],
            "expected": lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    "a": [1, 2],
                    "b": StructuredTensor.from_fields(
                        shape=[2, None],
                        fields={
                            "x": ragged_factory_ops.constant(
                                [[[1, 2], [5]], [[3]]])
                        })
                }),
        },
        {
            "testcase_name": "Ragged2DOfDict",
            "pyval": [[{"a": 1}, {"a": 2}, {"a": 3}],
                      [{"a": 4}, {"a": 5}]],
            "expected": lambda: StructuredTensor.from_fields(
                shape=[2, None],
                fields={"a": ragged_factory_ops.constant([[1, 2, 3], [4, 5]])})
        },
        {
            # Without a type spec, every tensor of rank > 1 is encoded as
            # ragged.
            "testcase_name": "MatrixOfDictWithoutTypeSpec",
            "pyval": [[{"a": 1}, {"a": 2}, {"a": 3}],
                      [{"a": 4}, {"a": 5}, {"a": 6}]],
            "expected": lambda: StructuredTensor.from_fields(
                shape=[2, None],
                fields={
                    "a": ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
                })
        },
        {
            # A type spec can fix the StructuredTensor's shape.
            "testcase_name": "MatrixOfDictWithTypeSpec",
            "pyval": [[{"a": 1}, {"a": 2}, {"a": 3}],
                      [{"a": 4}, {"a": 5}, {"a": 6}]],
            "type_spec": structured_tensor.StructuredTensorSpec(
                [2, 3], {"a": tensor_spec.TensorSpec(None, dtypes.int32)}),
            "expected": lambda: StructuredTensor.from_fields(
                shape=[2, 3], fields={"a": [[1, 2, 3], [4, 5, 6]]})
        },
    ])  # pyformat: disable
    def testPyvalConversion(self, pyval, expected, type_spec=None):
        """from_pyval builds the expected value; to_pyval round-trips it.

        Args:
          pyval: Python value passed to `StructuredTensor.from_pyval`.
          expected: Zero-arg callable producing the expected result; it is a
            thunk so tensors are only created once the test body runs.
          type_spec: Optional spec forwarded to `from_pyval`.
        """
        expected_value = expected()
        actual = structured_tensor.StructuredTensor.from_pyval(
            pyval, type_spec)
        self.assertAllEqual(actual, expected_value)
        # The round trip is only checked for StructuredTensor results, and
        # only in eager mode, where to_pyval is available.
        round_trippable = (
            isinstance(actual, structured_tensor.StructuredTensor)
            and context.executing_eagerly())
        if round_trippable:
            self.assertEqual(actual.to_pyval(), pyval)

    @parameterized.named_parameters([
        dict(testcase_name="MissingKeys",
             pyval=[{"a": [1, 2]}, {"b": [3, 4]}],
             err=KeyError,
             msg="'b'"),
        dict(testcase_name="TypeSpecMismatch_DictKey",
             pyval={"a": 1},
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[1],
                 field_specs={"b": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListDictKey",
             pyval=[{"a": 1}],
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[1],
                 field_specs={"b": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_RankMismatch",
             pyval=[{"a": 1}],
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[],
                 field_specs={"a": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg=r"Value does not match typespec \(rank mismatch\)"),
        dict(testcase_name="TypeSpecMismatch_Scalar",
             pyval=0,
             type_spec=structured_tensor.StructuredTensorSpec(shape=[],
                                                              field_specs={}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListTensor",
             pyval={"a": [[1]]},
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[],
                 field_specs={"a": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListSparse",
             pyval=[1, 2],
             type_spec=sparse_tensor.SparseTensorSpec([None], dtypes.int32),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListStruct",
             pyval=[[1]],
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[1, 1],
                 field_specs={"a": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="InconsistentDictionaryDepth",
             pyval=[{}, [{}]],
             msg="Inconsistent depth of dictionaries"),
        # The outer list mixes a list of dicts with a bare scalar (5).
        dict(testcase_name="ListElementNotDictOrList",
             pyval=[[{}], 5],
             msg="Expected dict or nested list/tuple of dict"),
    ])  # pyformat: disable
    def testFromPyvalError(self,
                           pyval,
                           err=ValueError,
                           type_spec=None,
                           msg=None):
        """from_pyval raises `err` with a message matching `msg`.

        Args:
          pyval: Invalid Python value passed to `from_pyval`.
          err: Expected exception type (ValueError unless overridden).
          type_spec: Optional spec forwarded to `from_pyval`.
          msg: Regex the exception message must match.
        """
        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
        with self.assertRaisesRegex(err, msg):
            structured_tensor.StructuredTensor.from_pyval(pyval, type_spec)

    def testToPyvalRequiresEagerMode(self):
        """to_pyval raises ValueError when called outside eager mode."""
        st = structured_tensor.StructuredTensor.from_pyval({"a": 5})
        # The error path only exists in graph mode; in eager mode there is
        # nothing to check.
        if not context.executing_eagerly():
            with self.assertRaisesRegex(ValueError,
                                        "only supported in eager mode."):
                st.to_pyval()

    @parameterized.named_parameters([
        ("Rank0", []),
        ("Rank1", [5, 3]),
        ("Rank2", [5, 8, 3]),
        ("Rank5", [1, 2, 3, 4, 5]),
    ])
    def testRowPartitionsFromUniformShape(self, shape):
        """Uniform row partitions reproduce the leading dims of `shape`.

        For each prefix rank, expects (rank - 1) partitions (never negative)
        whose first nrows() is shape[0] and whose i-th uniform_row_length()
        is shape[i + 1].
        """
        for rank in range(len(shape)):
            partitions = structured_tensor._row_partitions_for_uniform_shape(
                ops.convert_to_tensor(shape), rank)
            self.assertLen(partitions, max(0, rank - 1))
            if partitions:
                self.assertAllEqual(shape[0], partitions[0].nrows())
            for dim_index, partition in enumerate(partitions, start=1):
                self.assertAllEqual(shape[dim_index],
                                    partition.uniform_row_length())

    @parameterized.named_parameters([
        # Shape legend: U = uniform dimension; R = ragged dimension.
        dict(testcase_name="Shape_UR_Rank2", rt=[[1, 2], [], [3]],
             rt_ragged_rank=1, rank=2, expected_row_lengths=[[2, 0, 1]]),
        dict(testcase_name="Shape_URR_Rank2", rt=[[[1, 2], []], [[3]]],
             rt_ragged_rank=2, rank=2, expected_row_lengths=[[2, 1]]),
        dict(testcase_name="Shape_URU_Rank2", rt=[[[1], [2]], [[3]]],
             rt_ragged_rank=1, rank=2, expected_row_lengths=[[2, 1]]),
        dict(testcase_name="Shape_URR_Rank3", rt=[[[1, 2], []], [[3]]],
             rt_ragged_rank=2, rank=3,
             expected_row_lengths=[[2, 1], [2, 0, 1]]),
        dict(testcase_name="Shape_URU_Rank3", rt=[[[1], [2]], [[3]]],
             rt_ragged_rank=1, rank=3,
             expected_row_lengths=[[2, 1], [1, 1, 1]]),
        dict(testcase_name="Shape_URRUU_Rank2", rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2, rank=2, expected_row_lengths=[[1]]),
        dict(testcase_name="Shape_URRUU_Rank3", rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2, rank=3, expected_row_lengths=[[1], [1]]),
        dict(testcase_name="Shape_URRUU_Rank4", rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2, rank=4,
             expected_row_lengths=[[1], [1], [1]]),
        dict(testcase_name="Shape_URRUU_Rank5", rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2, rank=5,
             expected_row_lengths=[[1], [1], [1], [2]]),
    ])
    def testRowPartitionsForRaggedTensor(self, rt, rt_ragged_rank, rank,
                                         expected_row_lengths):
        """Row partitions derived from a RaggedTensor have expected lengths.

        Args:
          rt: Nested list used to build the RaggedTensor.
          rt_ragged_rank: ragged_rank passed to ragged_factory_ops.constant.
          rank: Number of dimensions to produce partitions for.
          expected_row_lengths: One row_lengths list per expected partition.
        """
        ragged = ragged_factory_ops.constant(rt, rt_ragged_rank)
        partitions = structured_tensor._row_partitions_for_ragged_tensor(
            ragged, rank, dtypes.int64)
        self.assertLen(partitions, rank - 1)
        self.assertLen(partitions, len(expected_row_lengths))
        for got, want in zip(partitions, expected_row_lengths):
            self.assertAllEqual(got.row_lengths(), want)

    @parameterized.named_parameters([
        dict(testcase_name="2D_0_1",
             st=[[{"x": 1}, {"x": 2}], [{"x": 3}]],
             outer_axis=0,
             inner_axis=1,
             expected=[{"x": 1}, {"x": 2}, {"x": 3}]),
        dict(testcase_name="3D_0_1",
             st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
             outer_axis=0,
             inner_axis=1,
             expected=[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}]]),
        dict(testcase_name="3D_1_2",
             st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
             outer_axis=1,
             inner_axis=2,
             expected=[[{"x": 1}, {"x": 2}, {"x": 3}], [{"x": 4}]]),
        dict(testcase_name="3D_0_2",
             st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
             outer_axis=0,
             inner_axis=2,
             expected=[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]),
        dict(testcase_name="4D_0_1",
             st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
                 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
             outer_axis=0,
             inner_axis=1,
             expected=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]],
                       [[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]),
        dict(testcase_name="4D_0_2",
             st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
                 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
             outer_axis=0,
             inner_axis=2,
             expected=[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}],
                       [{"x": 5}], [{"x": 6}], [{"x": 7}]]),
        dict(testcase_name="4D_0_3",
             st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
                 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
             outer_axis=0,
             inner_axis=3,
             expected=[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4},
                       {"x": 5}, {"x": 6}, {"x": 7}]),
        dict(testcase_name="4D_1_2",
             st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
                 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
             outer_axis=1,
             inner_axis=2,
             expected=[[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}]],
                       [[{"x": 5}], [{"x": 6}], [{"x": 7}]]]),
        dict(testcase_name="4D_1_3",
             st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
                 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
             outer_axis=1,
             inner_axis=3,
             expected=[[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}],
                       [{"x": 5}, {"x": 6}, {"x": 7}]]),
        dict(testcase_name="4D_2_3",
             st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
                 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
             outer_axis=2,
             inner_axis=3,
             expected=[[[{"x": 1}, {"x": 2}, {"x": 3}], [{"x": 4}]],
                       [[{"x": 5}], [{"x": 6}, {"x": 7}]]]),
    ])  # pyformat: disable
    def testMergeDims(self, st, outer_axis, inner_axis, expected):
        """merge_dims(outer_axis, inner_axis) matches the expected pyval.

        Args:
          st: Nested list of dicts converted with `from_pyval`.
          outer_axis: First axis of the range passed to `merge_dims`.
          inner_axis: Last axis of the range passed to `merge_dims`.
          expected: Python value the merged result must equal.
        """
        structured = StructuredTensor.from_pyval(st)
        merged = structured.merge_dims(outer_axis, inner_axis)
        self.assertAllEqual(merged, expected)

    def testMergeDims_0_1(self):
        """Merging dims 0..1 of a twice-partitioned struct drops a partition."""
        values = array_ops.constant([[1, 2], [3, 4], [5, 6]])
        rt = ragged_tensor.RaggedTensor.from_value_rowids(values, [0, 0, 1])
        nested = StructuredTensor.from_fields({"r": rt}, [2])
        for _ in range(2):
            nested = nested.partition_outer_dimension(
                row_partition.RowPartition.from_row_splits([0, 1, 2]))
        self.assertLen(nested.row_partitions, 2)
        flattened = nested.merge_dims(0, 1)
        self.assertLen(flattened.row_partitions, 1)

    def testMergeDimsError(self):
        """merge_dims raises ValueError when outer_axis > inner_axis."""
        st = StructuredTensor.from_pyval([[[{"a": 5}]]])
        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
        with self.assertRaisesRegex(
                ValueError,
                r"Expected outer_axis \(2\) to be less than inner_axis \(1\)"):
            st.merge_dims(2, 1)

    def testTupleFieldValue(self):
        """field_value accepts tuple paths into nested StructuredTensors."""
        st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
        self.assertAllEqual(st.field_value(("a", )), 5)
        self.assertAllEqual(st.field_value(("b", "c")), [1, 2, 3])
        # "a" holds a plain value (5), not a nested struct, so the path
        # ("a", "b") cannot be resolved.
        with self.assertRaisesRegex(
                KeyError, r"Field path \('a', 'b'\) not found in .*"):
            st.field_value(("a", "b"))

    def testRepr(self):
        """repr() of a scalar StructuredTensor lists its field names and shape."""
        st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
        self.assertEqual(repr(st),
                         "<StructuredTensor(fields={'a', 'b'}, shape=())>")