コード例 #1
0
    def testNestedStructConstruction(self):
        """Builds StructuredTensors holding ragged and nested-struct fields.

        One outer struct is a scalar (shape []) and the other a vector
        (shape [2]); each has a ragged field "r" and a StructuredTensor field
        "s". Shape, rank, field names and field values are then verified.
        """
        ragged = ragged_factory_ops.constant([[1, 2], [3]])
        inner_scalar = StructuredTensor.from_fields(fields={"x": [1, 2]},
                                                    shape=[])
        inner_vector = StructuredTensor.from_fields(fields={"x": [1, 2]},
                                                    shape=[2])
        outer_scalar = StructuredTensor.from_fields(
            fields={"r": ragged, "s": inner_scalar}, shape=[])
        outer_vector = StructuredTensor.from_fields(
            fields={"r": ragged, "s": inner_vector}, shape=[2])

        # (struct under test, expected shape, expected rank, expected "s").
        cases = (
            (outer_scalar, [], 0, inner_scalar),
            (outer_vector, [2], 1, inner_vector),
        )
        for outer, want_shape, want_rank, want_inner in cases:
            self.assertEqual(outer.shape.as_list(), want_shape)
            self.assertEqual(outer.rank, want_rank)
            self.assertEqual(set(outer.field_names()), {"r", "s"})
            self.assertAllEqual(outer.field_value("r"), ragged)
            self.assertAllEqual(outer.field_value("s"), want_inner)
コード例 #2
0
def _structured_tensor_like(t):
  """Returns a field-less StructuredTensor shaped like `t`.

  Args:
    t: an `ops.Tensor`, a ragged tensor, or otherwise an object exposing
      `shape`, `row_partitions` and `nrows()` (treated as a StructuredTensor).

  Returns:
    A StructuredTensor with an empty field dict built from the shape
    information of `t`.
  """
  if isinstance(t, ops.Tensor):
    return _structured_tensor_from_dense_tensor(t)

  if not ragged_tensor.is_ragged(t):
    # Neither dense nor ragged: by elimination this is a StructuredTensor.
    return StructuredTensor.from_fields(
        {}, shape=t.shape, row_partitions=t.row_partitions, nrows=t.nrows())

  ragged_shape = t.get_shape()
  partitions = _all_nested_row_partitions(t)
  return StructuredTensor.from_fields(
      {}, shape=ragged_shape, row_partitions=partitions)
コード例 #3
0
def _structured_tensor_from_dense_tensor(t):
  """Returns a field-less StructuredTensor shaped like the dense tensor `t`.

  Args:
    t: a dense tensor; only its static `shape` and, when needed, its dynamic
      shape are used.

  Returns:
    A StructuredTensor with no fields.

  Raises:
    ValueError: if the static shape of `t` has unknown rank.
  """
  # A rank-0 tensor either has a fully defined shape or an unknown rank, so
  # the first two checks below cover every scalar case.
  static_shape = t.shape
  if static_shape.is_fully_defined():
    return StructuredTensor.from_fields({}, shape=static_shape)

  rank = static_shape.rank
  if rank is None:
    raise ValueError("Can't build StructuredTensor w/ unknown rank")

  if rank == 1:
    # Only the outer dimension is unknown; read it from the dynamic shape.
    dynamic_nrows = array_ops.shape(t)[0]
    return StructuredTensor.from_fields(
        {}, shape=static_shape, nrows=dynamic_nrows)

  # Higher rank: go through a RaggedTensor to obtain the row partitions.
  as_ragged = ragged_tensor.RaggedTensor.from_tensor(t)
  return _structured_tensor_from_row_partitions(
      t.shape, as_ragged._nested_row_partitions)
コード例 #4
0
 def testToFromComponentsEmptyScalar(self):
     """Round-trips a field-less scalar StructuredTensor through its spec."""
     original = StructuredTensor.from_fields(shape=[], fields={})
     type_spec_ = original._type_spec
     parts = type_spec_._to_components(original)
     restored = type_spec_._from_components(parts)
     self.assertAllEqual(original, restored)
     # NOTE(review): the triple is presumably (fields, nrows, row_partitions);
     # confirm against the spec implementation.
     expected_parts = ({}, (), ())
     self.assertEqual(parts, expected_parts)
コード例 #5
0
 def testFromFieldsErrors(self,
                          fields,
                          shape,
                          nrows=None,
                          row_partitions=None,
                          validate=False,
                          err=ValueError,
                          msg=None,
                          test_in_eager=True):
     """Checks that bad `StructuredTensor.from_fields` arguments raise.

     Args:
       fields: the fields argument, or a zero-argument callable returning it.
       shape: the shape argument passed to `from_fields`.
       nrows: the nrows argument, or a zero-argument callable returning it.
       row_partitions: the row_partitions argument, or a zero-argument
         callable returning it.
       validate: forwarded to `from_fields`.
       err: the exception type expected to be raised.
       msg: regular expression the exception message must match.
       test_in_eager: when False, the test body is skipped in eager mode.
     """
     if not test_in_eager and context.executing_eagerly():
         return
     if callable(fields):
         fields = fields()  # deferred construction.
     if callable(nrows):
         nrows = nrows()  # deferred construction.
     if callable(row_partitions):
         row_partitions = row_partitions()  # deferred construction.
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(err, msg):
         struct = StructuredTensor.from_fields(
             fields=fields,
             shape=shape,
             nrows=nrows,
             row_partitions=row_partitions,
             validate=validate)
         # Evaluate everything inside the context: the error may only surface
         # when the values are actually computed.
         for field_name in struct.field_names():
             self.evaluate(struct.field_value(field_name))
         self.evaluate(struct.nrows())
コード例 #6
0
 def testFromFields(self,
                    shape,
                    fields,
                    expected_shape=None,
                    nrows=None,
                    row_partitions=None):
     """Builds a StructuredTensor via `from_fields` and checks its contents.

     `fields`, `nrows` and `row_partitions` may each be given as a
     zero-argument callable so that tensor creation is deferred until the
     test runs. When `expected_shape` is None, `shape` is expected.
     """
     # Resolve deferred arguments (fields may include tensors).
     if callable(fields):
         fields = fields()
     if callable(nrows):
         nrows = nrows()
     if callable(row_partitions):
         row_partitions = row_partitions()
     if expected_shape is None:
         expected_shape = shape

     accepted_types = (ops.Tensor, structured_tensor.StructuredTensor,
                       ragged_tensor.RaggedTensor)
     for validate in (True, False):
         built = StructuredTensor.from_fields(
             fields,
             shape,
             nrows=nrows,
             row_partitions=row_partitions,
             validate=validate)
         self.assertEqual(built.shape.as_list(), expected_shape)
         self.assertLen(expected_shape, built.rank)
         self.assertCountEqual(built.field_names(), tuple(fields))
         for name, expected_value in fields.items():
             self.assertIsInstance(built.field_value(name), accepted_types)
             self.assertAllEqual(built.field_value(name), expected_value)
コード例 #7
0
 def testToFromComponents(self, shape, fields, field_specs):
     """Round-trips a StructuredTensor through an explicitly built Spec."""
     original = StructuredTensor.from_fields(fields, shape)
     ragged_shape_spec = DynamicRaggedShape.Spec(row_partitions=[],
                                                 static_inner_shape=shape,
                                                 dtype=dtypes.int64)
     st_spec = StructuredTensor.Spec(_ragged_shape=ragged_shape_spec,
                                     _fields=field_specs)
     parts = st_spec._to_components(original)
     restored = st_spec._from_components(parts)
     self.assertAllEqual(original, restored)
コード例 #8
0
 def testToFromComponents(self, shape, fields, field_specs):
     """Round-trips a StructuredTensor through a StructuredTensorSpec.

     The spec's component tuple is expected to have three entries, the first
     of which holds the field values.
     """
     original = StructuredTensor.from_fields(fields, shape)
     st_spec = StructuredTensorSpec(shape, field_specs)
     parts = st_spec._to_components(original)
     self.assertLen(parts, 3)
     field_part = parts[0]
     self.assertAllTensorsEqual(field_part, fields)
     restored = st_spec._from_components(parts)
     self.assertAllEqual(original, restored)
コード例 #9
0
    def testPartitionOuterDimsErrors(self):
        """Checks the errors raised by `partition_outer_dimension`.

        A scalar StructuredTensor cannot be partitioned (ValueError), and a
        non-RowPartition argument is rejected (TypeError).
        """
        st = StructuredTensor.from_fields({})
        partition = row_partition.RowPartition.from_row_splits([0])
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    r"Shape \(\) must have rank at least 1"):
            st.partition_outer_dimension(partition)

        with self.assertRaisesRegex(TypeError,
                                    "row_partition must be a RowPartition"):
            st.partition_outer_dimension(10)
コード例 #10
0
 def testMergeDims_0_1(self):
     """Merging dims 0 and 1 drops one of the two row partitions."""
     flat_values = array_ops.constant([[1, 2], [3, 4], [5, 6]])
     ragged = ragged_tensor.RaggedTensor.from_value_rowids(
         flat_values, [0, 0, 1])
     base = StructuredTensor.from_fields({"r": ragged}, [2])

     # Wrap the struct in two extra outer dimensions, one at a time.
     once = base.partition_outer_dimension(
         row_partition.RowPartition.from_row_splits([0, 1, 2]))
     twice = once.partition_outer_dimension(
         row_partition.RowPartition.from_row_splits([0, 1, 2]))
     self.assertLen(twice.row_partitions, 2)

     collapsed = twice.merge_dims(0, 1)
     self.assertLen(collapsed.row_partitions, 1)
コード例 #11
0
 def testZerosLikeObject(self, row_partitions, shape, dtype, expected):
     """Checks `zeros_like_object` on a field-less StructuredTensor.

     `row_partitions`, when not None, is an iterable of row-splits lists that
     are converted to RowPartition objects before building the struct.
     """
     if row_partitions is not None:
         partitions = []
         for splits in row_partitions:
             partitions.append(
                 row_partition.RowPartition.from_row_splits(splits))
         row_partitions = partitions
     empty_struct = StructuredTensor.from_fields(
         {}, shape=shape, row_partitions=row_partitions)
     # NOTE: zeros_like is very robust. There aren't arguments that
     # should cause this operation to fail.
     result = structured_array_ops.zeros_like_object(empty_struct, dtype)
     self.assertAllEqual(result, expected)
コード例 #12
0
  def testRank(self, row_partitions, shape, expected):
    """Checks `structured_array_ops.rank` on a field-less StructuredTensor.

    `row_partitions`, when not None, is an iterable of row-splits lists that
    are converted to RowPartition objects before building the struct.
    """
    if row_partitions is not None:
      partitions = []
      for splits in row_partitions:
        partitions.append(row_partition.RowPartition.from_row_splits(splits))
      row_partitions = partitions
    empty_struct = StructuredTensor.from_fields(
        {}, shape=shape, row_partitions=row_partitions)

    # NOTE: rank is very robust. There aren't arguments that
    # should cause this operation to fail.
    computed_rank = structured_array_ops.rank(empty_struct)
    self.assertAllEqual(expected, computed_rank)
コード例 #13
0
 def testToFromComponentsEmptyTensor(self):
   """Round-trips a field-less [1, 2, 3] StructuredTensor through its spec.

   Also checks that the components are an (nrows, row_partitions) pair whose
   two partitions carry the expected row splits.
   """
   original = StructuredTensor.from_fields(shape=[1, 2, 3], fields={})
   type_spec_ = original._type_spec
   parts = type_spec_._to_components(original)
   restored = type_spec_._from_components(parts)
   self.assertAllEqual(original, restored)

   self.assertLen(parts, 2)
   (nrows, partitions) = parts
   self.assertAllEqual(nrows, 1)
   self.assertLen(partitions, 2)
   for partition in partitions:
     self.assertIsInstance(partition, row_partition.RowPartition)
   expected_splits = ([0, 2], [0, 3, 6])
   for partition, splits in zip(partitions, expected_splits):
     self.assertAllEqual(partition.row_splits(), splits)
コード例 #14
0
  def testSizeObject(self, row_partitions, shape, dtype, expected):
    """Checks `array_ops.size` and `size_v2` on a field-less StructuredTensor.

    `row_partitions`, when not None, is an iterable of row-splits lists that
    are converted to RowPartition objects before building the struct.
    """
    if row_partitions is not None:
      partitions = []
      for splits in row_partitions:
        partitions.append(row_partition.RowPartition.from_row_splits(splits))
      row_partitions = partitions
    empty_struct = StructuredTensor.from_fields(
        {}, shape=shape, row_partitions=row_partitions)
    # NOTE: size is very robust. There aren't arguments that
    # should cause this operation to fail.
    size_v1_result = array_ops.size(empty_struct, out_type=dtype)
    self.assertAllEqual(size_v1_result, expected)

    size_v2_result = array_ops.size_v2(empty_struct, out_type=dtype)
    self.assertAllEqual(size_v2_result, expected)
コード例 #15
0
def _extend_op(values, leaf_op, empty_st_op=None):
    """Lift an op on (ragged) tensors so it also works on StructuredTensors.

    Recurses through the fields of the structured tensors in `values`.
    `leaf_op` is applied to every group of non-structured values, and
    `empty_st_op` is applied at every StructuredTensor node; its result is
    returned directly for a node without fields, and supplies the shape of
    the rebuilt node otherwise.

    Args:
      values: a list of structured tensors, ragged tensors, or tensors. All
        must have the same type. If they are structured tensors, they must
        have the same paths.
      leaf_op: an op for handling non-structured tensors.
      empty_st_op: op to create a structured tensor without fields. When
        None, it is derived from `leaf_op` via `empty_st_op_like_zeros`.

    Returns:
      The result of the extended op (a StructuredTensor, RaggedTensor, or
      Tensor).

    Raises:
      ValueError: If values is not a Sequence or is empty.
    """
    if not isinstance(values, Sequence):
        raise ValueError('Expected a list')
    if not values:
        raise ValueError('List cannot be empty')
    if empty_st_op is None:
        empty_st_op = empty_st_op_like_zeros(leaf_op)

    # The first element serves as the structural template; all elements are
    # assumed to share its structure.
    template = values[0]
    if not isinstance(template, StructuredTensor):
        return leaf_op(values)

    # TODO(martinz): Calling empty_st_op may add unnecessary ops. Revisit later.
    skeleton = empty_st_op(values)
    names = template.field_names()
    if not names:
        return skeleton

    extended_fields = {
        name: _extend_op([v.field_value(name) for v in values], leaf_op,
                         empty_st_op)
        for name in names
    }
    return StructuredTensor.from_fields(extended_fields, shape=skeleton.shape)
コード例 #16
0
def _expand_dims_impl(st, axis, name=None):  # pylint: disable=redefined-builtin
    """Returns a StructuredTensor with a size-1 axis inserted at `axis`.

    This is the StructuredTensor implementation of tf.expand_dims. The
    `axis` must be less than or equal to the rank of `st`.

    >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
    >>> tf.expand_dims(st, 0).to_pyval()
    [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
    >>> tf.expand_dims(st, 1).to_pyval()
    [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
    >>> tf.expand_dims(st, 2).to_pyval()
    [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
    >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
    [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

    Args:
      st: the original StructuredTensor.
      axis: where to insert the dimension: `-(rank + 1) <= axis <= rank`.
      name: the name of the op.

    Returns:
      A new structured tensor whose rank is one larger than that of `st`.

    Raises:
      An error if `axis < -(rank + 1)` or `rank < axis`.
    """
    # Normalise a possibly negative axis against the *output* rank.
    axis = array_ops.get_positive_axis(
        axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
    with ops.name_scope(name, 'ExpandDims', [st, axis]):
        expanded_fields = {}
        for key, value in st._fields.items():
            expanded_fields[key] = array_ops.expand_dims(value, axis)
        expanded_shape = st.shape[:axis] + (1, ) + st.shape[axis:]
        expanded_partitions = _expand_st_row_partitions(st, axis)
        # Inserting at the front makes the new outer dimension exactly 1 row.
        if axis > 0:
            outer_nrows = st.nrows()
        else:
            outer_nrows = 1
        return StructuredTensor.from_fields(
            expanded_fields,
            shape=expanded_shape,
            row_partitions=expanded_partitions,
            nrows=outer_nrows)
コード例 #17
0
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
    """Tests for StructuredTensorSpec construction, serialization and batching."""

    # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
    # assertAllEqual etc to work with StructuredTensors.
    def assertAllEqual(self, a, b, msg=None):
        """Asserts `a` and `b` are equal, with StructuredTensor support.

        When neither argument is a StructuredTensor this defers to the base
        class. When both are, shapes, field names and (recursively) field
        values are compared.

        Raises:
          ValueError: if exactly one argument is a StructuredTensor.
        """
        if not (isinstance(a, structured_tensor.StructuredTensor)
                or isinstance(b, structured_tensor.StructuredTensor)):
            return super(StructuredTensorSpecTest,
                         self).assertAllEqual(a, b, msg)
        if not (isinstance(a, structured_tensor.StructuredTensor)
                and isinstance(b, structured_tensor.StructuredTensor)):
            # TODO(edloper) Add support for this once structured_factory_ops is added.
            raise ValueError('Not supported yet')

        self.assertEqual(repr(a.shape), repr(b.shape))
        self.assertEqual(set(a.field_names()), set(b.field_names()))
        for field in a.field_names():
            self.assertAllEqual(a.field_value(field), b.field_value(field))

    def assertAllTensorsEqual(self, list1, list2):
        """Asserts two collections of tensors are pairwise equal.

        Two dicts are compared key by key: both must have the same key set,
        and the values under each key must be equal. Iterating the dicts
        directly would only compare their keys and would silently skip the
        tensors themselves. Any other inputs are compared positionally.
        """
        if isinstance(list1, dict) and isinstance(list2, dict):
            self.assertEqual(set(list1), set(list2))
            for key in list1:
                self.assertAllEqual(list1[key], list2[key])
            return
        self.assertLen(list1, len(list2))
        for (t1, t2) in zip(list1, list2):
            self.assertAllEqual(t1, t2)

    def testConstruction(self):
        """Checks that the spec stores the shape and field specs it is given."""
        spec1_fields = dict(a=T_1_2_3_4)
        spec1 = StructuredTensorSpec([1, 2, 3], spec1_fields)
        self.assertEqual(spec1._shape, (1, 2, 3))
        self.assertEqual(spec1._field_specs, spec1_fields)

        spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
        spec2 = StructuredTensorSpec([1, 2], spec2_fields)
        self.assertEqual(spec2._shape, (1, 2))
        self.assertEqual(spec2._field_specs, spec2_fields)

    @parameterized.parameters([
        (None, {}, r"StructuredTensor's shape must have known rank\."),
        ([], None, r'field_specs must be a dictionary\.'),
        ([], {
            1: tensor_spec.TensorSpec(None)
        }, r'field_specs must be a dictionary with string keys\.'),
        ([], {
            'x': 0
        }, r'field_specs must be a dictionary with TypeSpec values\.'),
    ])
    def testConstructionErrors(self, shape, field_specs, error):
        """Checks that invalid constructor arguments raise TypeError."""
        with self.assertRaisesRegex(TypeError, error):
            structured_tensor.StructuredTensorSpec(shape, field_specs)

    def testValueType(self):
        """Checks that the spec reports StructuredTensor as its value type."""
        spec1 = StructuredTensorSpec([1, 2, 3], dict(a=T_1_2))
        self.assertEqual(spec1.value_type, StructuredTensor)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3],
                              {}), (tensor_shape.TensorShape([1, 2, 3]), {})),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), (tensor_shape.TensorShape([1, 2]), {
            'a': T_1_2,
            'b': R_1_N
        })),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
    ])  # pyformat: disable
    def testSerialize(self, spec, expected):
        """Checks that `_serialize` returns the (shape, field_specs) pair."""
        serialization = spec._serialize()
        # Note that we can only use assertEqual because none of our cases include
        # a None dimension. A TensorShape with a None dimension is never equal
        # to another TensorShape.
        self.assertEqual(serialization, expected)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3], {}), {}),
        (StructuredTensorSpec([], {'a': T_1_2}), {
            'a': T_1_2
        }),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), {
            'a': T_1_2,
            'b': R_1_N
        }),
        (StructuredTensorSpec([], {'a': T_1_2}), {
            'a': T_1_2
        }),
    ])  # pyformat: disable
    def testComponentSpecs(self, spec, expected):
        """Checks that `_component_specs` equals the expected field specs."""
        self.assertEqual(spec._component_specs, expected)

    @parameterized.parameters([
        {
            'shape': [],
            'fields': dict(x=[[1.0, 2.0]]),
            'field_specs': dict(x=T_1_2),
        },
        # TODO(edloper): Enable this test once we update StructuredTensorSpec
        # to contain the shared row partitions.
        #{
        #    'shape': [1, 2, 3],
        #    'fields': {},
        #    'field_specs': {},
        #},
        {
            'shape': [2],
            'fields':
            dict(a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
                 b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
            'field_specs':
            dict(a=R_1_N, b=T_2_3),
        },
    ])  # pyformat: disable
    def testToFromComponents(self, shape, fields, field_specs):
        """Round-trips a StructuredTensor through `_to/_from_components`."""
        struct = StructuredTensor.from_fields(fields, shape)
        spec = StructuredTensorSpec(shape, field_specs)
        actual_components = spec._to_components(struct)
        # The components are expected to mirror the input fields.
        self.assertAllTensorsEqual(actual_components, fields)
        rt_reconstructed = spec._from_components(actual_components)
        self.assertAllEqual(struct, rt_reconstructed)

    @parameterized.parameters([{
        'unbatched': StructuredTensorSpec([], {}),
        'batch_size': 5,
        'batched': StructuredTensorSpec([5], {}),
    }, {
        'unbatched': StructuredTensorSpec([1, 2], {}),
        'batch_size': 5,
        'batched': StructuredTensorSpec([5, 1, 2], {}),
    }, {
        'unbatched':
        StructuredTensorSpec([], dict(a=T_3, b=R_1_N)),
        'batch_size':
        2,
        'batched':
        StructuredTensorSpec([2], dict(a=T_2_3, b=R_2_1_N)),
    }])  # pyformat: disable
    def testBatchUnbatch(self, unbatched, batch_size, batched):
        """Checks that `_batch` and `_unbatch` are inverse on specs."""
        self.assertEqual(unbatched._batch(batch_size), batched)
        self.assertEqual(batched._unbatch(), unbatched)

    @parameterized.parameters([
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields({
                    'a': 1,
                    'b': [5, 6]
                }),
                StructuredTensor.from_fields({
                    'a': 2,
                    'b': [7, 8]
                })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(shape=[2],
                                                 fields={
                                                     'a': [1, 2],
                                                     'b': [[5, 6], [7, 8]]
                                                 }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(shape=[3],
                                             fields={
                                                 'a': [1, 2, 3],
                                                 'b': [[5, 6], [6, 7], [7, 8]]
                                             }),
                StructuredTensor.from_fields(shape=[3],
                                             fields={
                                                 'a': [2, 3, 4],
                                                 'b': [[2, 2], [3, 3], [4, 4]]
                                             })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2, 3],
                fields={
                    'a': [[1, 2, 3], [2, 3, 4]],
                    'b': [[[5, 6], [6, 7], [7, 8]], [[2, 2], [3, 3], [4, 4]]]
                }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 1,
                        'b': StructuredTensor.from_fields({'x': [5]})
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 2,
                        'b': StructuredTensor.from_fields({'x': [6]})
                    })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b':
                    StructuredTensor.from_fields(shape=[2],
                                                 fields={'x': [[5], [6]]})
                }),
        },
    ])  # pyformat: disable
    def testBatchUnbatchValues(self, unbatched, batch_size, batched):
        """Checks batching and unbatching of actual StructuredTensor values.

        `unbatched` and `batched` are zero-argument callables so that tensor
        creation is deferred until the test runs.
        """
        batched = batched()  # Deferred init because it creates tensors.
        unbatched = unbatched()  # Deferred init because it creates tensors.

        # Test batching.
        unbatched_spec = type_spec.type_spec_from_value(unbatched[0])
        unbatched_tensor_lists = [
            unbatched_spec._to_tensor_list(st) for st in unbatched
        ]
        # Stack the i-th tensor of every unbatched value into one batched tensor.
        batched_tensor_list = [
            array_ops.stack(tensors)
            for tensors in zip(*unbatched_tensor_lists)
        ]
        actual_batched = unbatched_spec._batch(batch_size)._from_tensor_list(
            batched_tensor_list)
        self.assertAllEqual(actual_batched, batched)

        # Test unbatching
        batched_spec = type_spec.type_spec_from_value(batched)
        batched_tensor_list = batched_spec._to_tensor_list(batched)
        # Unstack each batched tensor, then regroup the pieces per batch item.
        unbatched_tensor_lists = zip(
            *[array_ops.unstack(tensor) for tensor in batched_tensor_list])
        actual_unbatched = [
            batched_spec._unbatch()._from_tensor_list(tensor_list)
            for tensor_list in unbatched_tensor_lists
        ]
        self.assertLen(actual_unbatched, len(unbatched))
        for (actual, expected) in zip(actual_unbatched, unbatched):
            self.assertAllEqual(actual, expected)
コード例 #18
0
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):

    # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
    # assertAllEqual etc to work with StructuredTensors.
    def assertAllEqual(self, a, b, msg=None):
        if not (isinstance(a, structured_tensor.StructuredTensor)
                or isinstance(b, structured_tensor.StructuredTensor)):
            return super(StructuredTensorSpecTest,
                         self).assertAllEqual(a, b, msg)
        if not (isinstance(a, structured_tensor.StructuredTensor)
                and isinstance(b, structured_tensor.StructuredTensor)):
            # TODO(edloper) Add support for this once structured_factory_ops is added.
            raise ValueError('Not supported yet')

        self.assertEqual(repr(a.shape), repr(b.shape))
        self.assertEqual(set(a.field_names()), set(b.field_names()))
        for field in a.field_names():
            self.assertAllEqual(a.field_value(field), b.field_value(field))

    def assertAllTensorsEqual(self, x, y):
        assert isinstance(x, dict) and isinstance(y, dict)
        self.assertEqual(set(x), set(y))
        for key in x:
            self.assertAllEqual(x[key], y[key])

    def testConstruction(self):
        spec1_fields = dict(a=T_1_2_3_4)
        spec1 = StructuredTensorSpec([1, 2, 3], spec1_fields)
        self.assertEqual(spec1._shape, (1, 2, 3))
        self.assertEqual(spec1._field_specs, spec1_fields)

        spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
        spec2 = StructuredTensorSpec([1, 2], spec2_fields)
        self.assertEqual(spec2._shape, (1, 2))
        self.assertEqual(spec2._field_specs, spec2_fields)

    @parameterized.parameters([
        (None, {}, r"StructuredTensor's shape must have known rank\."),
        ([], None, r'field_specs must be a dictionary\.'),
        ([], {
            1: tensor_spec.TensorSpec(None)
        }, r'field_specs must be a dictionary with string keys\.'),
        ([], {
            'x': 0
        }, r'field_specs must be a dictionary with TypeSpec values\.'),
    ])
    def testConstructionErrors(self, shape, field_specs, error):
        with self.assertRaisesRegex(TypeError, error):
            structured_tensor.StructuredTensorSpec(shape, field_specs)

    def testValueType(self):
        spec1 = StructuredTensorSpec([1, 2], dict(a=T_1_2))
        self.assertEqual(spec1.value_type, StructuredTensor)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3],
                              {}), (tensor_shape.TensorShape([1, 2, 3]), {})),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), (tensor_shape.TensorShape([1, 2]), {
            'a': T_1_2,
            'b': R_1_N
        })),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
    ])  # pyformat: disable
    def testSerialize(self, spec, expected):
        serialization = spec._serialize()
        # Note that we can only use assertEqual because none of our cases include
        # a None dimension. A TensorShape with a None dimension is never equal
        # to another TensorShape.
        self.assertEqual(serialization, expected)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3], {}),
         ({}, NROWS_SPEC, (PARTITION_SPEC, PARTITION_SPEC))),
        (StructuredTensorSpec([], {'a': T_1_2}), ({
            'a': T_1_2
        }, (), ())),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), ({
            'a': T_1_2,
            'b': R_1_N
        }, NROWS_SPEC, (PARTITION_SPEC, ))),
        (StructuredTensorSpec([], {'a': T_1_2}), ({
            'a': T_1_2
        }, (), ())),
    ])  # pyformat: disable
    def testComponentSpecs(self, spec, expected):
        self.assertEqual(spec._component_specs, expected)

    @parameterized.parameters([
        {
            'shape': [],
            'fields': dict(x=[[1.0, 2.0]]),
            'field_specs': dict(x=T_1_2),
        },
        {
            'shape': [2],
            'fields':
            dict(a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
                 b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
            'field_specs':
            dict(a=R_1_N, b=T_2_3),
        },
    ])  # pyformat: disable
    def testToFromComponents(self, shape, fields, field_specs):
        struct = StructuredTensor.from_fields(fields, shape)
        spec = StructuredTensorSpec(shape, field_specs)
        actual_components = spec._to_components(struct)
        self.assertLen(actual_components, 3)
        self.assertAllTensorsEqual(actual_components[0], fields)
        rt_reconstructed = spec._from_components(actual_components)
        self.assertAllEqual(struct, rt_reconstructed)

    def testToFromComponentsEmptyScalar(self):
        struct = StructuredTensor.from_fields(fields={}, shape=[])
        spec = struct._type_spec
        components = spec._to_components(struct)
        rt_reconstructed = spec._from_components(components)
        self.assertAllEqual(struct, rt_reconstructed)
        self.assertEqual(components, ({}, (), ()))

    def testToFromComponentsEmptyTensor(self):
        struct = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
        spec = struct._type_spec
        components = spec._to_components(struct)
        rt_reconstructed = spec._from_components(components)
        self.assertAllEqual(struct, rt_reconstructed)
        self.assertLen(components, 3)
        fields, nrows, row_partitions = components
        self.assertEmpty(fields)
        self.assertAllEqual(nrows, 1)
        self.assertLen(row_partitions, 2)
        self.assertIsInstance(row_partitions[0], row_partition.RowPartition)
        self.assertIsInstance(row_partitions[1], row_partition.RowPartition)
        self.assertAllEqual(row_partitions[0].row_splits(), [0, 2])
        self.assertAllEqual(row_partitions[1].row_splits(), [0, 3, 6])

    @parameterized.parameters([{
        'unbatched': StructuredTensorSpec([], {}),
        'batch_size': 5,
        'batched': StructuredTensorSpec([5], {}),
    }, {
        'unbatched': StructuredTensorSpec([1, 2], {}),
        'batch_size': 5,
        'batched': StructuredTensorSpec([5, 1, 2], {}),
    }, {
        'unbatched':
        StructuredTensorSpec([], dict(a=T_3, b=R_1_N)),
        'batch_size':
        2,
        'batched':
        StructuredTensorSpec([2], dict(a=T_2_3, b=R_2_1_N)),
    }])  # pyformat: disable
    def testBatchUnbatch(self, unbatched, batch_size, batched):
        self.assertEqual(unbatched._batch(batch_size), batched)
        self.assertEqual(batched._unbatch(), unbatched)

    @parameterized.parameters([
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields({
                    'a': 1,
                    'b': [5, 6]
                }),
                StructuredTensor.from_fields({
                    'a': 2,
                    'b': [7, 8]
                })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(shape=[2],
                                                 fields={
                                                     'a': [1, 2],
                                                     'b': [[5, 6], [7, 8]]
                                                 }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(shape=[3],
                                             fields={
                                                 'a': [1, 2, 3],
                                                 'b': [[5, 6], [6, 7], [7, 8]]
                                             }),
                StructuredTensor.from_fields(shape=[3],
                                             fields={
                                                 'a': [2, 3, 4],
                                                 'b': [[2, 2], [3, 3], [4, 4]]
                                             })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2, 3],
                fields={
                    'a': [[1, 2, 3], [2, 3, 4]],
                    'b': [[[5, 6], [6, 7], [7, 8]], [[2, 2], [3, 3], [4, 4]]]
                }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 1,
                        'b': StructuredTensor.from_fields({'x': [5]})
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 2,
                        'b': StructuredTensor.from_fields({'x': [6]})
                    })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b':
                    StructuredTensor.from_fields(shape=[2],
                                                 fields={'x': [[5], [6]]})
                }),
        },
        {
            'unbatched':
            lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d':
                        ragged_factory_ops.constant_value([[1, 2], [3]]),
                        'Ragged2d':
                        ragged_factory_ops.constant_value([1]),
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                        'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
                    })
            ],
            'batch_size':
            2,
            'batched':
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'Ragged3d':
                    ragged_factory_ops.constant_value([[[1, 2], [3]], [[1]]]),
                    'Ragged2d':
                    ragged_factory_ops.constant_value([[1], [2, 3]]),
                }),
            'use_only_batched_spec':
            True,
        },
    ])  # pyformat: disable
    def testBatchUnbatchValues(self,
                               unbatched,
                               batch_size,
                               batched,
                               use_only_batched_spec=False):
        """Round-trips values through the spec's tensor-list encodings.

        Args:
          unbatched: zero-argument callable returning the sequence of
            per-element values.
          batch_size: value passed to the element spec's `_batch`.
          batched: zero-argument callable returning the batched value.
          use_only_batched_spec: if true, the element spec is obtained by
            calling `_unbatch()` on the spec of the batched value instead of
            being taken from the first unbatched element.
        """
        # Both values are built here, not in the decorator, because
        # constructing them creates tensors.
        batched = batched()
        unbatched = unbatched()

        # ---- Batching: encode every element, stack, decode as a batch. ----
        if use_only_batched_spec:
            element_spec = type_spec.type_spec_from_value(batched)._unbatch()
        else:
            element_spec = type_spec.type_spec_from_value(unbatched[0])

        per_element_components = []
        for element in unbatched:
            per_element_components.append(
                element_spec._to_tensor_list(element))

        # zip(*...) regroups the encodings so that each group holds the same
        # component taken from every element.
        stacked_components = []
        for component_group in zip(*per_element_components):
            stacked_components.append(array_ops.stack(component_group))

        actual_batched = element_spec._batch(batch_size)._from_tensor_list(
            stacked_components)
        self.assertTrue(
            element_spec._batch(batch_size).is_compatible_with(
                actual_batched))
        self.assertAllEqual(actual_batched, batched)

        # ---- Unbatching: encode the batch, unstack, decode each element. ----
        batched_spec = type_spec.type_spec_from_value(batched)
        batched_components = batched_spec._to_batched_tensor_list(batched)
        unstacked_components = [
            array_ops.unstack(component) for component in batched_components
        ]

        actual_unbatched = []
        for element_components in zip(*unstacked_components):
            actual_unbatched.append(
                batched_spec._unbatch()._from_tensor_list(element_components))

        self.assertLen(actual_unbatched, len(unbatched))
        for element in actual_unbatched:
            self.assertTrue(
                batched_spec._unbatch().is_compatible_with(element))
        for actual, expected in zip(actual_unbatched, unbatched):
            self.assertAllEqual(actual, expected)

    def _lambda_for_fields(self):
        """Returns a zero-argument callable that builds the test fields dict.

        Construction is deferred because the dict contains ragged and
        structured values. The returned callable produces the keys 'a'..'h':
        two numpy arrays, four ragged constants and two StructuredTensors
        built from nested Python lists.
        """

        def build_fields():
            fields = {}
            fields['a'] = np.ones([1, 2, 3, 1])
            fields['b'] = np.ones([1, 2, 3, 1, 5])
            fields['c'] = ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1], dtype=np.uint8), dtype=dtypes.uint8)
            fields['d'] = ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1)
            fields['e'] = ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2)
            fields['f'] = ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 3]), dtype=dtypes.float32)

            # pylint: disable=g-complex-comprehension
            # One {'x', 'y'} dict per (j, k) pair, nested as [1][j][k][1].
            fields['g'] = StructuredTensor.from_pyval([[
                [[{'x': j, 'y': k}] for k in range(3)]
                for j in range(2)
            ]])
            # As 'g', plus an innermost list of length j (empty for j == 0).
            fields['h'] = StructuredTensor.from_pyval([[
                [[[{'x': j, 'y': k, 'z': z} for z in range(j)]]
                 for k in range(3)]
                for j in range(2)
            ]])
            # pylint: enable=g-complex-comprehension
            return fields

        return build_fields

    def testFlatTensorSpecs(self):
        """Checks `_flat_tensor_specs` for a rank-4 struct with mixed fields.

        The batchable tensor list encoding for a StructuredTensor contains a
        separate tensor for each leaf field. The expected list has 11
        entries: the two dense fields keep their float64 shapes, and every
        other leaf (four ragged fields, two leaves under "g", three under
        "h") is a variant spec of unknown shape.
        """
        fields = self._lambda_for_fields()
        rank = 4
        # Deferred construction: fields may include tensors.
        fields = fields() if callable(fields) else fields

        struct = StructuredTensor.from_fields_and_rank(fields, rank)
        spec = type_spec.type_spec_from_value(struct)
        flat_specs = spec._flat_tensor_specs

        def variant_spec():
            return tensor_spec.TensorSpec(
                shape=None, dtype=dtypes.variant, name=None)

        expected_specs = [
            # a, b
            tensor_spec.TensorSpec(
                shape=(1, 2, 3, 1), dtype=dtypes.float64, name=None),
            tensor_spec.TensorSpec(
                shape=(1, 2, 3, 1, 5), dtype=dtypes.float64, name=None),
        ]
        # c, d, e, f
        expected_specs.extend(variant_spec() for _ in range(4))
        # g
        expected_specs.extend(variant_spec() for _ in range(2))
        # h
        expected_specs.extend(variant_spec() for _ in range(3))

        self.assertEqual(flat_specs, expected_specs)

    def testFulTypesForFlatTensors(self):
        """Checks `fulltypes_for_flat_tensors` for the shared mixed fields.

        The expected list has one entry per flat tensor spec: TFT_UNSET for
        the two dense fields, and TFT_RAGGED wrapping the element type for
        every other leaf (uint8 for "c", float for "d"-"f", int32 for the
        leaves under "g" and "h").
        """
        fields = self._lambda_for_fields()
        rank = 4
        # Deferred construction: fields may include tensors.
        fields = fields() if callable(fields) else fields

        struct = StructuredTensor.from_fields_and_rank(fields, rank)
        spec = type_spec.type_spec_from_value(struct)
        flat_specs = spec._flat_tensor_specs
        fulltype = fulltypes_for_flat_tensors(spec)

        def unset():
            return full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)

        def ragged_of(element_type_id):
            return full_type_pb2.FullTypeDef(
                type_id=full_type_pb2.TFT_RAGGED,
                args=[full_type_pb2.FullTypeDef(type_id=element_type_id)])

        # a, b
        expected_ft_list = [unset(), unset()]
        # c
        expected_ft_list.append(ragged_of(full_type_pb2.TFT_UINT8))
        # d, e, f
        expected_ft_list.extend(
            ragged_of(full_type_pb2.TFT_FLOAT) for _ in range(3))
        # g
        expected_ft_list.extend(
            ragged_of(full_type_pb2.TFT_INT32) for _ in range(2))
        # h
        expected_ft_list.extend(
            ragged_of(full_type_pb2.TFT_INT32) for _ in range(3))

        self.assertEqual(len(expected_ft_list), len(flat_specs))
        self.assertEqual(fulltype, expected_ft_list)
コード例 #19
0
def _structured_tensor_from_row_partitions(shape, row_partitions):
  """Creates a StructuredTensor with no fields from shape and partitions."""
  no_fields = {}
  return StructuredTensor.from_fields(
      no_fields, shape=shape, row_partitions=row_partitions)
コード例 #20
0
class StructuredTensorTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
    def assertAllEqual(self, a, b, msg=None):
        """Extends `assertAllEqual` to compare StructuredTensors.

        If neither argument is a StructuredTensor, the check is delegated to
        the parent class. If exactly one is, the other is converted with
        `StructuredTensor.from_pyval` and the two are compared field by field
        without comparing shapes. If both are, shapes are compared as well.

        Args:
          a: first value to compare.
          b: second value to compare.
          msg: optional message forwarded to the underlying assertions.
        """
        struct_type = structured_tensor.StructuredTensor
        a_is_struct = isinstance(a, struct_type)
        b_is_struct = isinstance(b, struct_type)

        if not (a_is_struct or b_is_struct):
            return super(StructuredTensorTest, self).assertAllEqual(a, b, msg)

        # Shapes are only checked when neither side had to be converted.
        check_shape = a_is_struct and b_is_struct
        if not a_is_struct:
            a = struct_type.from_pyval(a)
        elif not b_is_struct:
            b = struct_type.from_pyval(b)
        self._assertStructuredEqual(a, b, msg, check_shape)

    def _assertStructuredEqual(self, a, b, msg, check_shape):
        """Asserts that two StructuredTensors hold the same fields and values.

        Args:
          a: first StructuredTensor.
          b: second StructuredTensor.
          msg: message forwarded to `assertAllEqual` for leaf values.
          check_shape: if true, the `repr` of the two shapes must match; the
            flag is passed on unchanged to nested StructuredTensor fields.
        """
        struct_type = structured_tensor.StructuredTensor
        if check_shape:
            self.assertEqual(repr(a.shape), repr(b.shape))

        names_a = set(a.field_names())
        names_b = set(b.field_names())
        self.assertEqual(names_a, names_b)

        for name in a.field_names():
            left = a.field_value(name)
            right = b.field_value(name)
            # Values must be of exactly the same type, not merely equal.
            self.assertIs(type(left), type(right))
            if not isinstance(left, struct_type):
                self.assertAllEqual(left, right, msg)
                continue
            self._assertStructuredEqual(left, right, msg, check_shape)

    def testConstructorIsPrivate(self):
        """Calling the StructuredTensor constructor directly raises ValueError."""
        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(
                ValueError, "StructuredTensor constructor is private"):
            structured_tensor.StructuredTensor({}, (), None, ())

    @parameterized.named_parameters([
        # Scalar (rank=0) StructuredTensors.
        {
            "testcase_name": "Rank0_WithNoFields",
            "shape": [],
            "fields": {},
        },
        {
            "testcase_name": "Rank0_WithTensorFields",
            "shape": [],
            "fields": {
                "Foo": 5,
                "Bar": [1, 2, 3]
            },
        },
        {
            "testcase_name": "Rank0_WithRaggedFields",
            "shape": [],
            "fields": {
                # note: fields have varying rank & ragged_rank.
                "p":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "q":
                ragged_factory_ops.constant_value([[[4]], [], [[5, 6]]]),
                "r":
                ragged_factory_ops.constant_value([[[4]], [], [[5]]],
                                                  ragged_rank=1),
                "s":
                ragged_factory_ops.constant_value([[[4]], [], [[5]]],
                                                  ragged_rank=2),
            },
        },
        {
            "testcase_name": "Rank0_WithStructuredFields",
            "shape": [],
            "fields": lambda: {
                "foo":
                StructuredTensor.from_pyval({
                    "a": 1,
                    "b": [1, 2, 3]
                }),
                "bar":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
        },
        {
            "testcase_name": "Rank0_WithMixedFields",
            "shape": [],
            "fields": lambda: {
                "f1": 5,
                "f2": [1, 2, 3],
                "f3": ragged_factory_ops.constant_value([[1, 2], [3]]),
                "f4": StructuredTensor.from_pyval({
                    "a": 1,
                    "b": [1, 2, 3]
                }),
            },
        },
        # Vector (rank=1) StructuredTensors.
        {
            "testcase_name": "Rank1_WithNoFields",
            "shape": [2],
            "fields": {},
        },
        {
            "testcase_name": "Rank1_WithExplicitNrows",
            "shape": [None],
            "nrows": 2,
            "fields": {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]]
            },
            "expected_shape": [2],
        },
        {
            "testcase_name": "Rank1_WithTensorFields",
            "shape": [2],
            "fields": {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]]
            },
        },
        {
            "testcase_name": "Rank1_WithRaggedFields",
            "shape": [2],
            "fields": {
                # note: fields have varying rank & ragged_rank.
                "p":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "q":
                ragged_factory_ops.constant_value([[[4]], [[5, 6], [7]]]),
                "r":
                ragged_factory_ops.constant_value([[], [[[12]], [[13]]]]),
                "s":
                ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
                                                  ragged_rank=1),
                "t":
                ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
                                                  ragged_rank=2),
            },
        },
        {
            "testcase_name": "Rank1_WithStructuredFields",
            "shape": [2],
            "fields": lambda: {
                "foo":
                StructuredTensor.from_pyval([{
                    "a": 1,
                    "b": [1, 2, 3]
                }, {
                    "a": 2,
                    "b": []
                }]),
                "bar":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
        },
        {
            "testcase_name": "Rank1_WithMixedFields",
            "shape": [2],
            "fields": lambda: {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]],
                "r":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "s":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
        },
        {
            "testcase_name": "Rank1_WithNoElements",
            "shape": [0],
            "fields": lambda: {
                "x": [],
                "y": np.zeros([0, 8]),
                "r": ragged_factory_ops.constant([], ragged_rank=1),
                "s": StructuredTensor.from_pyval([]),
            },
        },
        {
            "testcase_name": "Rank1_InferDimSize",
            "shape": [None],
            "fields": lambda: {
                "x": [1, 2],
                "y": [[1, 2], [3, 4]],
                "r":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "p":
                ragged_factory_ops.constant_value([[4], [5, 6, 7]]),
                "foo":
                StructuredTensor.from_pyval([{
                    "a": 1,
                    "b": [1, 2, 3]
                }, {
                    "a": 2,
                    "b": []
                }]),
                "bar":
                StructuredTensor.from_pyval([[{
                    "x": 12
                }], [{
                    "x": 13
                }, {
                    "x": 14
                }]]),
            },
            "expected_shape": [2],  # inferred from field values.
        },
        # Matrix (rank=2) StructuredTensors.
        {
            "testcase_name": "Rank2_WithNoFields",
            "shape": [2, 8],
            "fields": {},
        },
        {
            "testcase_name":
            "Rank2_WithNoFieldsAndExplicitRowPartitions",
            "shape": [2, None],
            "row_partitions":
            lambda: [row_partition.RowPartition.from_row_lengths([3, 7])],
            "fields": {},
        },
        {
            "testcase_name": "Rank2_WithTensorFields",
            "shape": [None, None],
            "fields": {
                "x": [[1, 2, 3], [4, 5, 6]],
                "y": np.ones([2, 3, 8])
            },
            "expected_shape": [2, 3],  # inferred from field values.
        },
        {
            "testcase_name": "Rank2_WithRaggedFields",
            "shape": [2, None],  # ragged shape = [[*, *], [*]]
            "fields": {
                # Note: fields must have identical row_splits.
                "a":
                ragged_factory_ops.constant_value([[1, 2], [3]]),
                "b":
                ragged_factory_ops.constant_value([[4, 5], [6]]),
                "c":
                ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
                "d":
                ragged_factory_ops.constant_value([[[[1, 2], [3]],
                                                    [[4], [], [5]]],
                                                   [[[6, 7, 8], []]]]),
            },
        },
        {
            "testcase_name": "Rank2_WithStructuredFields",
            "shape": [2, None],  # ragged shape = [[*], [*, *]]
            "fields": lambda: {
                # Note: fields must have identical row_splits.
                "a":
                StructuredTensor.from_pyval([[{
                    "x": 1
                }], [{
                    "x": 2
                }, {
                    "x": 3
                }]]),
                "b":
                StructuredTensor.from_pyval([[[{
                    "y": 1
                }]], [[], [{
                    "y": 2
                }, {
                    "y": 3
                }]]]),
            },
        },
        {
            "testcase_name": "Rank2_WithMixedFields",
            "shape": [2, None],
            "fields": lambda: {
                "a": [[1, 2], [3, 4]],
                "b":
                ragged_factory_ops.constant_value([[1, 2], [3, 4]]),
                "c":
                StructuredTensor.from_pyval([[[{
                    "y": 1
                }], []], [[], [{
                    "y": 2
                }, {
                    "y": 3
                }]]]),
                "d":
                ragged_factory_ops.constant_value([[[1, 2], []], [[3], [4]]]),
            },
            "expected_shape": [2, 2],
        },
        # Rank=4 StructuredTensors.
        {
            "testcase_name":
            "Rank4_WithNoFields",
            "shape": [1, None, None, 3],
            "fields": {},
            "row_partitions":
            lambda: [
                row_partition.RowPartition.from_row_lengths([3]),
                row_partition.RowPartition.from_row_lengths([2, 0, 1]),
                row_partition.RowPartition.from_uniform_row_length(3, nvals=9)
            ]
        },
        {
            "testcase_name": "Rank4_WithMixedFields",
            "shape": [1, None, None, 1],
            "fields": lambda: {
                "a":
                np.ones([1, 2, 3, 1]),
                "b":
                np.ones([1, 2, 3, 1, 5]),
                "c":
                ragged_factory_ops.constant(np.zeros([1, 2, 3, 1])),
                "d":
                ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3]).tolist(),
                                            ragged_rank=1),
                "e":
                ragged_factory_ops.constant(
                    np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
                "f":
                ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3])),
                "g":
                StructuredTensor.from_pyval([[[[{
                    "x": j,
                    "y": k
                }] for k in range(3)] for j in range(2)]]),
                "h":
                StructuredTensor.from_pyval([[[[[{
                    "x": j,
                    "y": k,
                    "z": z
                } for z in range(j)]] for k in range(3)] for j in range(2)]]),
            },
            "expected_shape": [1, 2, 3, 1],  # inferred from field values.
        },
    ])  # pyformat: disable
    def testFromFields(self,
                       shape,
                       fields,
                       expected_shape=None,
                       nrows=None,
                       row_partitions=None):
        """Builds a StructuredTensor with `from_fields` and checks the result.

        Args:
          shape: shape passed to `from_fields`.
          fields: dict of field values, or a zero-argument callable returning
            one.
          expected_shape: expected static shape as a list; defaults to
            `shape` when not given.
          nrows: optional `nrows` argument, or a callable returning it.
          row_partitions: optional `row_partitions` argument, or a callable
            returning it.
        """
        # Deferred construction: the arguments may include tensors, so the
        # parameter table supplies callables that are resolved here.
        if callable(fields):
            fields = fields()
        if callable(nrows):
            nrows = nrows()
        if callable(row_partitions):
            row_partitions = row_partitions()

        if expected_shape is None:
            expected_shape = shape
        allowed_value_types = (ops.Tensor, structured_tensor.StructuredTensor,
                               ragged_tensor.RaggedTensor)

        # Construction is exercised both with and without validation.
        for validate in (True, False):
            struct = StructuredTensor.from_fields(
                fields,
                shape,
                nrows=nrows,
                row_partitions=row_partitions,
                validate=validate)
            self.assertEqual(struct.shape.as_list(), expected_shape)
            self.assertLen(expected_shape, struct.rank)
            self.assertCountEqual(struct.field_names(), tuple(fields.keys()))
            for name, expected_value in fields.items():
                self.assertIsInstance(struct.field_value(name),
                                      allowed_value_types)
                self.assertAllEqual(struct.field_value(name), expected_value)

    @parameterized.parameters([
        dict(fields={}, shape=object(), err=TypeError),
        dict(fields=object(),
             shape=[],
             err=TypeError,
             msg="fields must be a dictionary"),
        dict(fields={1: 2},
             shape=[],
             err=TypeError,
             msg="Unexpected type for key"),
        dict(fields={"x": object()},
             shape=[],
             err=TypeError,
             msg="Unexpected type for value"),
        dict(fields={},
             shape=None,
             err=ValueError,
             msg="StructuredTensor's shape must have known rank"),
        dict(
            fields={"f": 5},
            shape=[5],
            err=ValueError,
            msg=r"Field f has shape \(\), which is incompatible with the shape "
            r"that was specified or inferred from other fields: \(5,\)"),
        dict(fields=dict(x=[1], y=[]),
             shape=[None],
             err=ValueError,
             msg=r"Field . has shape .*, which is incompatible with the shape "
             r"that was specified or inferred from other fields: .*"),
        dict(fields={"": 5},
             shape=[],
             err=ValueError,
             msg="Field name '' is not currently allowed."),
        dict(fields={"_": 5},
             shape=[],
             err=ValueError,
             msg="Field name '_' is not currently allowed."),
        dict(
            fields={
                "r1": ragged_factory_ops.constant_value([[1, 2], [3]]),
                "r2": ragged_factory_ops.constant_value([[1, 2, 3], [4]])
            },
            shape=[2, None],
            validate=True,
            err=errors.InvalidArgumentError,
            msg=r"incompatible row_splits",
        ),
        dict(fields={},
             shape=(),
             nrows=5,
             err=ValueError,
             msg="nrows must be None if shape.rank==0"),
        dict(fields={},
             shape=(),
             row_partitions=[0],
             err=ValueError,
             msg=r"row_partitions must be None or \[\] if shape.rank<2"),
        dict(fields={},
             shape=(None, None, None),
             row_partitions=[],
             err=ValueError,
             msg=r"len\(row_partitions\) must be shape.rank-1"),
        dict(fields={},
             shape=[None],
             err=ValueError,
             msg="nrows must be specified if rank==1 and `fields` is empty."),
        dict(fields={},
             shape=[None, None],
             err=ValueError,
             msg="row_partitions must be specified if rank>1 and `fields` "
             "is empty."),
        dict(fields={},
             shape=[None, None],
             nrows=lambda: constant_op.constant(2, dtypes.int32),
             row_partitions=lambda:
             [row_partition.RowPartition.from_row_lengths([3, 4])],
             err=ValueError,
             msg="field values have incompatible row_partition dtypes"),
        dict(fields=lambda: {
            "a":
            ragged_factory_ops.constant([[1]], row_splits_dtype=dtypes.int32),
            "b":
            ragged_factory_ops.constant([[1]], row_splits_dtype=dtypes.int64)
        },
             shape=[None, None],
             err=ValueError,
             msg="field values have incompatible row_partition dtypes"),
        dict(fields=lambda: {
            "a": array_ops.placeholder_with_default(np.array([1, 2, 3]), None),
            "b": array_ops.placeholder_with_default(np.array([4, 5]), None)
        },
             validate=True,
             shape=[None],
             err=(ValueError, errors.InvalidArgumentError),
             msg="fields have incompatible nrows",
             test_in_eager=False),
    ])
    def testFromFieldsErrors(self,
                             fields,
                             shape,
                             nrows=None,
                             row_partitions=None,
                             validate=False,
                             err=ValueError,
                             msg=None,
                             test_in_eager=True):
        """Checks that invalid `from_fields` arguments raise the expected error.

        Args:
          fields: `fields` argument, or a zero-argument callable returning it.
          shape: `shape` argument.
          nrows: optional `nrows` argument, or a callable returning it.
          row_partitions: optional `row_partitions` argument, or a callable
            returning it.
          validate: forwarded to `from_fields`.
          err: expected exception type (or tuple of types).
          msg: regular expression the error message must match; None accepts
            any message.
          test_in_eager: if false, the case is not run when executing eagerly.
        """
        if not test_in_eager and context.executing_eagerly():
            return
        if callable(fields):
            fields = fields()  # deferred construction.
        if callable(nrows):
            nrows = nrows()  # deferred construction.
        if callable(row_partitions):
            row_partitions = row_partitions()  # deferred construction.
        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(err, msg):
            struct = StructuredTensor.from_fields(
                fields=fields,
                shape=shape,
                nrows=nrows,
                row_partitions=row_partitions,
                validate=validate)
            # The error may only surface when values are evaluated, so every
            # field and nrows is evaluated inside the context as well.
            for field_name in struct.field_names():
                self.evaluate(struct.field_value(field_name))
            self.evaluate(struct.nrows())

    def testMergeNrowsErrors(self):
        """`_merge_nrows` rejects a value whose length disagrees with nrows."""
        nrows = constant_op.constant(5)
        static_nrows = tensor_shape.Dimension(5)
        # Three elements, which conflicts with the nrows of 5 given above.
        value = constant_op.constant([1, 2, 3])
        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(ValueError,
                                    "fields have incompatible nrows"):
            structured_tensor._merge_nrows(nrows,
                                           static_nrows,
                                           value,
                                           dtypes.int32,
                                           validate=False)

    def testNestedStructConstruction(self):
        """Builds structs that nest another struct and a ragged tensor.

        Two outer StructuredTensors (shape [] and shape [2]) are created, each
        holding the same ragged field "r" and a nested struct "s" of matching
        shape; shape, rank, field names and field values are then checked.
        """
        ragged = ragged_factory_ops.constant([[1, 2], [3]])
        inner_scalar = StructuredTensor.from_fields(
            shape=[], fields={"x": [1, 2]})
        inner_vector = StructuredTensor.from_fields(
            shape=[2], fields={"x": [1, 2]})
        outer_scalar = StructuredTensor.from_fields(
            shape=[], fields={"r": ragged, "s": inner_scalar})
        outer_vector = StructuredTensor.from_fields(
            shape=[2], fields={"r": ragged, "s": inner_vector})

        # (outer struct, expected shape, expected rank, expected nested "s")
        cases = (
            (outer_scalar, [], 0, inner_scalar),
            (outer_vector, [2], 1, inner_vector),
        )
        for outer, shape, rank, nested in cases:
            self.assertEqual(outer.shape.as_list(), shape)
            self.assertEqual(outer.rank, rank)
            self.assertEqual(set(outer.field_names()), {"r", "s"})
            self.assertAllEqual(outer.field_value("r"), ragged)
            self.assertAllEqual(outer.field_value("s"), nested)

    def testPartitionOuterDims(self):
        """Checks `partition_outer_dimension` with several row partitions."""
        # This test only runs when executing eagerly. Report the skip
        # explicitly rather than returning, so that a graph-mode run is not
        # silently recorded as a pass.
        if not context.executing_eagerly():
            self.skipTest("testPartitionOuterDims only runs in eager mode.")
        a = dict(x=1, y=[1, 2])
        b = dict(x=2, y=[3, 4])
        c = dict(x=3, y=[5, 6])
        d = dict(x=4, y=[7, 8])
        st1 = StructuredTensor.from_pyval([a, b, c, d])

        # Row splits [0, 2, 2, 3, 4] group the four rows as 2, 0, 1 and 1.
        st2 = st1.partition_outer_dimension(
            row_partition.RowPartition.from_row_splits([0, 2, 2, 3, 4]))
        self.assertAllEqual(st2, [[a, b], [], [c], [d]])

        # Partitioning again adds another outer dimension on top of st2.
        st3 = st2.partition_outer_dimension(
            row_partition.RowPartition.from_row_lengths([1, 0, 3, 0]))
        self.assertAllEqual(st3, [[[a, b]], [], [[], [c], [d]], []])

        # If we partition with uniform_row_lengths, then `x` is partitioned into
        # a Tensor (not a RaggedTensor).
        st4 = st1.partition_outer_dimension(
            row_partition.RowPartition.from_uniform_row_length(
                uniform_row_length=2, nvals=4, nrows=2))
        self.assertAllEqual(
            st4,
            structured_tensor.StructuredTensor.from_pyval(
                [[a, b], [c, d]],
                structured_tensor.StructuredTensorSpec(
                    [2, 2], {
                        "x":
                        tensor_spec.TensorSpec([2, 2], dtypes.int32),
                        "y":
                        ragged_tensor.RaggedTensorSpec([2, 2, None],
                                                       dtypes.int32)
                    })))

    def testPartitionOuterDimension3(self):
        """Partitioning a rank-1 struct twice produces a rank-3 struct."""
        values = array_ops.constant([[1, 2], [3, 4], [5, 6]])
        ragged = ragged_tensor.RaggedTensor.from_value_rowids(
            values, [0, 0, 1])
        rank1 = structured_tensor.StructuredTensor.from_fields(
            {"r": ragged}, [2])

        first_split = row_partition.RowPartition.from_row_splits([0, 1, 2])
        rank2 = rank1.partition_outer_dimension(first_split)
        second_split = row_partition.RowPartition.from_row_splits([0, 1, 2])
        rank3 = rank2.partition_outer_dimension(second_split)

        self.assertEqual(3, rank3.rank)

    def testPartitionOuterDimsErrors(self):
        """Checks the errors raised by `partition_outer_dimension`.

        Covers two cases on a struct built from `from_fields({})`:
          * a `ValueError` whose message matches the rank-at-least-1 pattern
            when a RowPartition is passed;
          * a `TypeError` when the argument is not a RowPartition.
        """
        st = StructuredTensor.from_fields({})
        partition = row_partition.RowPartition.from_row_splits([0])
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in Python 3.12+.
        with self.assertRaisesRegex(ValueError,
                                    r"Shape \(\) must have rank at least 1"):
            st.partition_outer_dimension(partition)

        with self.assertRaisesRegex(TypeError,
                                    "row_partition must be a RowPartition"):
            st.partition_outer_dimension(10)

    @parameterized.named_parameters([
        {
            "testcase_name": "ScalarEmpty",
            "pyval": {},
            "expected":
            lambda: StructuredTensor.from_fields(shape=[], fields={})
        },
        {
            "testcase_name":
            "ScalarSimple",
            "pyval": {
                "a": 12,
                "b": [1, 2, 3],
                "c": [[1, 2], [3]]
            },
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[],
                fields={
                    "a": 12,
                    "b": [1, 2, 3],
                    "c": ragged_factory_ops.constant([[1, 2], [3]])
                })
        },
        {
            "testcase_name":
            "ScalarSimpleWithTypeSpec",
            "pyval": {
                "a": 12,
                "b": [1, 2, 3],
                "c": [[1, 2], [3]]
            },
            "type_spec":
            structured_tensor.StructuredTensorSpec(
                [], {
                    "a": tensor_spec.TensorSpec([], dtypes.int32),
                    "b": tensor_spec.TensorSpec([None], dtypes.int32),
                    "c": ragged_tensor.RaggedTensorSpec([None, None],
                                                        dtypes.int32)
                }),
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[],
                fields={
                    "a": 12,
                    "b": [1, 2, 3],
                    "c": ragged_factory_ops.constant([[1, 2], [3]])
                })
        },
        {
            "testcase_name":
            "ScalarWithNestedStruct",
            "pyval": {
                "a": 12,
                "b": [1, 2, 3],
                "c": {
                    "x": b"Z",
                    "y": [10, 20]
                }
            },
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[],
                fields={
                    "a":
                    12,
                    "b": [1, 2, 3],
                    "c":
                    StructuredTensor.from_fields(shape=[],
                                                 fields={
                                                     "x": "Z",
                                                     "y": [10, 20]
                                                 })
                })
        },
        {
            "testcase_name": "EmptyList",
            "pyval": [],
            "expected": lambda: [],
        },
        {
            "testcase_name": "ListOfEmptyList",
            "pyval": [[], []],
            "expected": lambda: [[], []],
        },
        {
            "testcase_name":
            "EmptyListWithTypeSpecAndFields",
            "pyval": [],
            "type_spec":
            structured_tensor.StructuredTensorSpec(
                [0], {"a": tensor_spec.TensorSpec(None, dtypes.int32)}),
            "expected":
            lambda: StructuredTensor.from_fields(shape=[0], fields={"a": []})
        },
        {
            "testcase_name":
            "EmptyListWithTypeSpecNoFieldsShape0_5",
            "pyval": [],
            "type_spec":
            structured_tensor.StructuredTensorSpec([0, 5], {}),
            "expected":
            lambda: StructuredTensor.from_fields(shape=[0, 5], fields={})
        },
        {
            "testcase_name":
            "EmptyListWithTypeSpecNoFieldsShape1_0",
            "pyval": [[]],
            "type_spec":
            structured_tensor.StructuredTensorSpec([1, 0], {}),
            "expected":
            lambda: StructuredTensor.from_fields(shape=[1, 0], fields={})
        },
        {
            "testcase_name":
            "VectorOfDict",
            "pyval": [{
                "a": 1
            }, {
                "a": 2
            }],
            "expected":
            lambda: StructuredTensor.from_fields(shape=[2],
                                                 fields={"a": [1, 2]})
        },
        {
            "testcase_name":
            "VectorOfDictWithNestedStructScalar",
            "pyval": [{
                "a": 1,
                "b": {
                    "x": [1, 2]
                }
            }, {
                "a": 2,
                "b": {
                    "x": [3]
                }
            }],
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    "a": [1, 2],
                    "b":
                    StructuredTensor.from_fields(
                        shape=[2],
                        fields=
                        {"x": ragged_factory_ops.constant([[1, 2], [3]])})
                }),
        },
        {
            "testcase_name":
            "VectorOfDictWithNestedStructVector",
            "pyval": [{
                "a": 1,
                "b": [{
                    "x": [1, 2]
                }, {
                    "x": [5]
                }]
            }, {
                "a": 2,
                "b": [{
                    "x": [3]
                }]
            }],
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    "a": [1, 2],
                    "b":
                    StructuredTensor.from_fields(
                        shape=[2, None],
                        fields={
                            "x":
                            ragged_factory_ops.constant([[[1, 2], [5]], [[3]]])
                        })
                }),
        },
        {
            "testcase_name":
            "Ragged2DOfDict",
            "pyval": [[
                {
                    "a": 1
                },
                {
                    "a": 2
                },
                {
                    "a": 3
                },
            ], [{
                "a": 4
            }, {
                "a": 5
            }]],
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[2, None],
                fields={"a": ragged_factory_ops.constant([[1, 2, 3], [4, 5]])})
        },
        {
            # With no type-spec, all tensors>1D are encoded as ragged:
            "testcase_name":
            "MatrixOfDictWithoutTypeSpec",
            "pyval": [[
                {
                    "a": 1
                },
                {
                    "a": 2
                },
                {
                    "a": 3
                },
            ], [{
                "a": 4
            }, {
                "a": 5
            }, {
                "a": 6
            }]],
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[2, None],
                fields=
                {"a": ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])})
        },
        {
            # TypeSpec can be used to specify StructuredTensor shape.
            "testcase_name":
            "MatrixOfDictWithTypeSpec",
            "pyval": [[
                {
                    "a": 1
                },
                {
                    "a": 2
                },
                {
                    "a": 3
                },
            ], [{
                "a": 4
            }, {
                "a": 5
            }, {
                "a": 6
            }]],
            "type_spec":
            structured_tensor.StructuredTensorSpec(
                [2, 3], {"a": tensor_spec.TensorSpec(None, dtypes.int32)}),
            "expected":
            lambda: StructuredTensor.from_fields(
                shape=[2, 3], fields={"a": [[1, 2, 3], [4, 5, 6]]})
        },
    ])  # pyformat: disable
    def testPyvalConversion(self, pyval, expected, type_spec=None):
        """Converts `pyval` with `from_pyval` and compares against `expected`.

        Args:
          pyval: The Python value handed to `StructuredTensor.from_pyval`.
          expected: Zero-argument callable returning the expected value.
          type_spec: Optional type spec forwarded to `from_pyval`.
        """
        # `expected` is called here rather than at decoration time because it
        # creates tensors.
        want = expected()
        got = structured_tensor.StructuredTensor.from_pyval(pyval, type_spec)
        self.assertAllEqual(got, want)

        # Round-trip through to_pyval, which is only available in eager mode.
        can_round_trip = (
            isinstance(got, structured_tensor.StructuredTensor)
            and context.executing_eagerly())
        if can_round_trip:
            self.assertEqual(got.to_pyval(), pyval)

    @parameterized.named_parameters([
        dict(testcase_name="MissingKeys",
             pyval=[{
                 "a": [1, 2]
             }, {
                 "b": [3, 4]
             }],
             err=KeyError,
             msg="'b'"),
        dict(testcase_name="TypeSpecMismatch_DictKey",
             pyval={"a": 1},
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[1],
                 field_specs={"b": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListDictKey",
             pyval=[{
                 "a": 1
             }],
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[1],
                 field_specs={"b": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_RankMismatch",
             pyval=[{
                 "a": 1
             }],
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[],
                 field_specs={"a": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg=r"Value does not match typespec \(rank mismatch\)"),
        dict(testcase_name="TypeSpecMismatch_Scalar",
             pyval=0,
             type_spec=structured_tensor.StructuredTensorSpec(shape=[],
                                                              field_specs={}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListTensor",
             pyval={"a": [[1]]},
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[],
                 field_specs={"a": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListSparse",
             pyval=[1, 2],
             type_spec=sparse_tensor.SparseTensorSpec([None], dtypes.int32),
             msg="Value does not match typespec"),
        dict(testcase_name="TypeSpecMismatch_ListStruct",
             pyval=[[1]],
             type_spec=structured_tensor.StructuredTensorSpec(
                 shape=[1, 1],
                 field_specs={"a": tensor_spec.TensorSpec([], dtypes.int32)}),
             msg="Value does not match typespec"),
        dict(testcase_name="InconsistentDictionaryDepth",
             pyval=[{}, [{}]],
             msg="Inconsistent depth of dictionaries"),
        dict(testcase_name="FOO",
             pyval=[[{}], 5],
             msg="Expected dict or nested list/tuple of dict"),
    ])  # pyformat: disable
    def testFromPyvalError(self,
                           pyval,
                           err=ValueError,
                           type_spec=None,
                           msg=None):
        """Checks that `from_pyval` rejects `pyval` with the expected error.

        Args:
          pyval: The Python value handed to `StructuredTensor.from_pyval`.
          err: Exception class expected to be raised (defaults to ValueError).
          type_spec: Optional type spec forwarded to `from_pyval`.
          msg: Regular expression the error message must match.
        """
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in Python 3.12+.
        with self.assertRaisesRegex(err, msg):
            structured_tensor.StructuredTensor.from_pyval(pyval, type_spec)

    def testToPyvalRequiresEagerMode(self):
        """Outside eager mode, `to_pyval` must raise a ValueError.

        When executing eagerly the struct is still built, but nothing is
        asserted.
        """
        st = structured_tensor.StructuredTensor.from_pyval({"a": 5})
        if not context.executing_eagerly():
            # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
            # alias, which no longer exists in Python 3.12+.
            with self.assertRaisesRegex(ValueError,
                                        "only supported in eager mode."):
                st.to_pyval()

    @parameterized.named_parameters([
        (
            "Rank0",
            [],
        ),
        (
            "Rank1",
            [5, 3],
        ),
        (
            "Rank2",
            [5, 8, 3],
        ),
        (
            "Rank5",
            [1, 2, 3, 4, 5],
        ),
    ])
    def testRowPartitionsFromUniformShape(self, shape):
        """Checks `_row_partitions_for_uniform_shape` for each rank < len(shape).

        For every rank the helper must return `max(0, rank - 1)` partitions;
        the first partition's nrows must equal `shape[0]`, and partition `i`
        must have uniform row length `shape[i + 1]`.

        Args:
          shape: List of ints describing a uniform shape.
        """
        for rank in range(len(shape)):
            shape_tensor = ops.convert_to_tensor(shape)
            partitions = structured_tensor._row_partitions_for_uniform_shape(
                shape_tensor, rank)
            expected_count = max(0, rank - 1)
            self.assertLen(partitions, expected_count)
            if partitions:
                outermost = partitions[0]
                self.assertAllEqual(shape[0], outermost.nrows())
            # Partition i splits dimension i + 1, so pair each partition with
            # the shape entries starting at index 1.
            for inner_size, partition in zip(shape[1:], partitions):
                self.assertAllEqual(inner_size,
                                    partition.uniform_row_length())

    @parameterized.named_parameters([
        # For shapes: U = uniform dimension; R = ragged dimension.
        dict(testcase_name="Shape_UR_Rank2",
             rt=[[1, 2], [], [3]],
             rt_ragged_rank=1,
             rank=2,
             expected_row_lengths=[[2, 0, 1]]),
        dict(testcase_name="Shape_URR_Rank2",
             rt=[[[1, 2], []], [[3]]],
             rt_ragged_rank=2,
             rank=2,
             expected_row_lengths=[[2, 1]]),
        dict(testcase_name="Shape_URU_Rank2",
             rt=[[[1], [2]], [[3]]],
             rt_ragged_rank=1,
             rank=2,
             expected_row_lengths=[[2, 1]]),
        dict(testcase_name="Shape_URR_Rank3",
             rt=[[[1, 2], []], [[3]]],
             rt_ragged_rank=2,
             rank=3,
             expected_row_lengths=[[2, 1], [2, 0, 1]]),
        dict(testcase_name="Shape_URU_Rank3",
             rt=[[[1], [2]], [[3]]],
             rt_ragged_rank=1,
             rank=3,
             expected_row_lengths=[[2, 1], [1, 1, 1]]),
        dict(testcase_name="Shape_URRUU_Rank2",
             rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2,
             rank=2,
             expected_row_lengths=[[1]]),
        dict(testcase_name="Shape_URRUU_Rank3",
             rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2,
             rank=3,
             expected_row_lengths=[[1], [1]]),
        dict(testcase_name="Shape_URRUU_Rank4",
             rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2,
             rank=4,
             expected_row_lengths=[[1], [1], [1]]),
        dict(testcase_name="Shape_URRUU_Rank5",
             rt=[[[[[1, 2]]]]],
             rt_ragged_rank=2,
             rank=5,
             expected_row_lengths=[[1], [1], [1], [2]]),
    ])
    def testRowPartitionsForRaggedTensor(self, rt, rt_ragged_rank, rank,
                                         expected_row_lengths):
        """Checks `_row_partitions_for_ragged_tensor` against known lengths.

        Args:
          rt: Nested Python list used to build the ragged tensor.
          rt_ragged_rank: Ragged rank passed to `ragged_factory_ops.constant`.
          rank: Rank requested from the helper; `rank - 1` partitions expected.
          expected_row_lengths: One list of row lengths per expected partition.
        """
        ragged = ragged_factory_ops.constant(rt, rt_ragged_rank)
        partitions = structured_tensor._row_partitions_for_ragged_tensor(
            ragged, rank, dtypes.int64)

        self.assertLen(partitions, rank - 1)
        self.assertLen(partitions, len(expected_row_lengths))
        for want_lengths, partition in zip(expected_row_lengths, partitions):
            self.assertAllEqual(partition.row_lengths(), want_lengths)

    @parameterized.named_parameters([
        dict(testcase_name="2D_0_1",
             st=[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]],
             outer_axis=0,
             inner_axis=1,
             expected=[{
                 "x": 1
             }, {
                 "x": 2
             }, {
                 "x": 3
             }]),
        dict(testcase_name="3D_0_1",
             st=[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]],
             outer_axis=0,
             inner_axis=1,
             expected=[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }], [{
                 "x": 4
             }]]),
        dict(testcase_name="3D_1_2",
             st=[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]],
             outer_axis=1,
             inner_axis=2,
             expected=[[{
                 "x": 1
             }, {
                 "x": 2
             }, {
                 "x": 3
             }], [{
                 "x": 4
             }]]),
        dict(testcase_name="3D_0_2",
             st=[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]],
             outer_axis=0,
             inner_axis=2,
             expected=[{
                 "x": 1
             }, {
                 "x": 2
             }, {
                 "x": 3
             }, {
                 "x": 4
             }]),
        dict(testcase_name="4D_0_1",
             st=[[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]], [[[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]],
             outer_axis=0,
             inner_axis=1,
             expected=[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]], [[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]),
        dict(testcase_name="4D_0_2",
             st=[[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]], [[[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]],
             outer_axis=0,
             inner_axis=2,
             expected=[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }], [{
                 "x": 4
             }], [{
                 "x": 5
             }], [{
                 "x": 6
             }], [{
                 "x": 7
             }]]),
        dict(testcase_name="4D_0_3",
             st=[[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]], [[[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]],
             outer_axis=0,
             inner_axis=3,
             expected=[{
                 "x": 1
             }, {
                 "x": 2
             }, {
                 "x": 3
             }, {
                 "x": 4
             }, {
                 "x": 5
             }, {
                 "x": 6
             }, {
                 "x": 7
             }]),
        dict(testcase_name="4D_1_2",
             st=[[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]], [[[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]],
             outer_axis=1,
             inner_axis=2,
             expected=[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }], [{
                 "x": 4
             }]], [[{
                 "x": 5
             }], [{
                 "x": 6
             }], [{
                 "x": 7
             }]]]),
        dict(testcase_name="4D_1_3",
             st=[[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]], [[[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]],
             outer_axis=1,
             inner_axis=3,
             expected=[[{
                 "x": 1
             }, {
                 "x": 2
             }, {
                 "x": 3
             }, {
                 "x": 4
             }], [{
                 "x": 5
             }, {
                 "x": 6
             }, {
                 "x": 7
             }]]),
        dict(testcase_name="4D_2_3",
             st=[[[[{
                 "x": 1
             }, {
                 "x": 2
             }], [{
                 "x": 3
             }]], [[{
                 "x": 4
             }]]], [[[{
                 "x": 5
             }]], [[{
                 "x": 6
             }], [{
                 "x": 7
             }]]]],
             outer_axis=2,
             inner_axis=3,
             expected=[[[{
                 "x": 1
             }, {
                 "x": 2
             }, {
                 "x": 3
             }], [{
                 "x": 4
             }]], [[{
                 "x": 5
             }], [{
                 "x": 6
             }, {
                 "x": 7
             }]]]),
    ])  # pyformat: disable
    def testMergeDims(self, st, outer_axis, inner_axis, expected):
        """Merges dims `outer_axis..inner_axis` and compares with `expected`.

        Args:
          st: Nested Python value converted with `StructuredTensor.from_pyval`.
          outer_axis: First axis passed to `merge_dims`.
          inner_axis: Second axis passed to `merge_dims`.
          expected: Python value the merged struct must equal.
        """
        structured = StructuredTensor.from_pyval(st)
        merged = structured.merge_dims(outer_axis, inner_axis)
        self.assertAllEqual(merged, expected)

    def testMergeDims_0_1(self):
        """Merging dims 0 and 1 of a twice-partitioned struct drops a partition.

        The struct has two row partitions before the merge and one after.
        """
        values = array_ops.constant([[1, 2], [3, 4], [5, 6]])
        ragged = ragged_tensor.RaggedTensor.from_value_rowids(
            values, [0, 0, 1])
        rank1 = StructuredTensor.from_fields({"r": ragged}, [2])

        first_split = row_partition.RowPartition.from_row_splits([0, 1, 2])
        rank2 = rank1.partition_outer_dimension(first_split)
        second_split = row_partition.RowPartition.from_row_splits([0, 1, 2])
        rank3 = rank2.partition_outer_dimension(second_split)
        self.assertLen(rank3.row_partitions, 2)

        collapsed = rank3.merge_dims(0, 1)
        self.assertLen(collapsed.row_partitions, 1)

    def testMergeDimsError(self):
        """`merge_dims` must raise ValueError when outer_axis > inner_axis."""
        st = StructuredTensor.from_pyval([[[{"a": 5}]]])
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in Python 3.12+.
        with self.assertRaisesRegex(
                ValueError,
                r"Expected outer_axis \(2\) to be less than inner_axis \(1\)"):
            st.merge_dims(2, 1)

    def testTupleFieldValue(self):
        """`field_value` accepts a tuple path into nested structs.

        A one-element path reads a top-level field, a two-element path reads a
        nested field, and a path through a non-struct field raises KeyError.
        """
        st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
        self.assertAllEqual(st.field_value(("a", )), 5)
        self.assertAllEqual(st.field_value(("b", "c")), [1, 2, 3])
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in Python 3.12+.
        with self.assertRaisesRegex(
                KeyError, r"Field path \('a', 'b'\) not found in .*"):
            st.field_value(("a", "b"))

    def testRepr(self):
        """`repr` of a scalar struct lists its field names and shape."""
        nested = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
        expected_repr = "<StructuredTensor(fields={'a', 'b'}, shape=())>"
        self.assertEqual(repr(nested), expected_repr)
コード例 #21
0
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):

    # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
    # assertAllEqual etc to work with StructuredTensors.
    def assertAllEqual(self, a, b, msg=None):
        """Asserts equality, with field-wise comparison for StructuredTensors.

        If neither argument is a StructuredTensor the call is delegated to the
        parent class. If both are, their shapes (by repr), field-name sets and
        each field value are compared recursively.

        Args:
          a: First value.
          b: Second value.
          msg: Optional message, forwarded only to the parent implementation.

        Raises:
          ValueError: If exactly one of `a` and `b` is a StructuredTensor.
        """
        a_is_struct = isinstance(a, structured_tensor.StructuredTensor)
        b_is_struct = isinstance(b, structured_tensor.StructuredTensor)

        if not a_is_struct and not b_is_struct:
            parent = super(StructuredTensorSpecTest, self)
            return parent.assertAllEqual(a, b, msg)
        if not (a_is_struct and b_is_struct):
            # TODO(edloper) Add support for this once structured_factory_ops
            # is added.
            raise ValueError('Not supported yet')

        self.assertEqual(repr(a.shape), repr(b.shape))
        self.assertEqual(set(a.field_names()), set(b.field_names()))
        for name in a.field_names():
            self.assertAllEqual(a.field_value(name), b.field_value(name))

    def assertAllTensorsEqual(self, x, y):
        """Asserts that two dicts have the same keys and equal values.

        Values are compared per key with `assertAllEqual`.

        Args:
          x: First dict.
          y: Second dict.
        """
        # Use unittest assertions rather than a bare `assert`, so the type
        # check still runs under `python -O` and reports a useful message.
        self.assertIsInstance(x, dict)
        self.assertIsInstance(y, dict)
        self.assertEqual(set(x), set(y))
        for key in x:
            self.assertAllEqual(x[key], y[key])

    def testConstruction(self):
        """A spec stores its shape as a tuple and its field specs unchanged.

        Checked for a spec with one tensor field, and for a spec whose fields
        include tensor specs, ragged specs and the first spec itself.
        """
        inner_fields = {'a': T_1_2_3_4}
        inner_spec = StructuredTensorSpec([1, 2, 3], inner_fields)
        self.assertEqual(inner_spec._shape, (1, 2, 3))
        self.assertEqual(inner_spec._field_specs, inner_fields)

        outer_fields = {
            'a': T_1_2,
            'b': T_1_2_8,
            'c': R_1_N,
            'd': R_1_N_N,
            's': inner_spec,
        }
        outer_spec = StructuredTensorSpec([1, 2], outer_fields)
        self.assertEqual(outer_spec._shape, (1, 2))
        self.assertEqual(outer_spec._field_specs, outer_fields)

    @parameterized.parameters([
        (None, {}, r"StructuredTensor's shape must have known rank\."),
        ([], None, r'field_specs must be a dictionary\.'),
        ([], {
            1: tensor_spec.TensorSpec(None)
        }, r'field_specs must be a dictionary with string keys\.'),
        ([], {
            'x': 0
        }, r'field_specs must be a dictionary with TypeSpec values\.'),
    ])
    def testConstructionErrors(self, shape, field_specs, error):
        """Invalid constructor arguments must raise a matching TypeError.

        Args:
          shape: Shape argument for `StructuredTensorSpec`.
          field_specs: Field-spec argument for `StructuredTensorSpec`.
          error: Regular expression the TypeError message must match.
        """
        spec_cls = structured_tensor.StructuredTensorSpec
        with self.assertRaisesRegex(TypeError, error):
            spec_cls(shape, field_specs)

    def testValueType(self):
        """A spec's `value_type` is `StructuredTensor`."""
        spec = StructuredTensorSpec([1, 2, 3], {'a': T_1_2})
        self.assertEqual(spec.value_type, StructuredTensor)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3],
                              {}), (tensor_shape.TensorShape([1, 2, 3]), {})),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), (tensor_shape.TensorShape([1, 2]), {
            'a': T_1_2,
            'b': R_1_N
        })),
        (StructuredTensorSpec([],
                              {'a': T_1_2}), (tensor_shape.TensorShape([]), {
                                  'a': T_1_2
                              })),
    ])  # pyformat: disable
    def testSerialize(self, spec, expected):
        """`spec._serialize()` must equal the `expected` value.

        Args:
          spec: The StructuredTensorSpec under test.
          expected: The value `_serialize` is expected to return.
        """
        # Plain assertEqual is only valid here because no test case has a
        # None dimension: a TensorShape containing None never compares equal
        # to another TensorShape.
        actual = spec._serialize()
        self.assertEqual(actual, expected)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3], {}),
         ({}, NROWS_SPEC, (PARTITION_SPEC, PARTITION_SPEC))),
        (StructuredTensorSpec([], {'a': T_1_2}), ({
            'a': T_1_2
        }, (), ())),
        (StructuredTensorSpec([1, 2], {
            'a': T_1_2,
            'b': R_1_N
        }), ({
            'a': T_1_2,
            'b': R_1_N
        }, NROWS_SPEC, (PARTITION_SPEC, ))),
        (StructuredTensorSpec([], {'a': T_1_2}), ({
            'a': T_1_2
        }, (), ())),
    ])  # pyformat: disable
    def testComponentSpecs(self, spec, expected):
        """`spec._component_specs` must equal `expected`.

        Args:
          spec: The StructuredTensorSpec under test.
          expected: The expected value of `_component_specs`.
        """
        actual = spec._component_specs
        self.assertEqual(actual, expected)

    @parameterized.parameters([
        {
            'shape': [],
            'fields': dict(x=[[1.0, 2.0]]),
            'field_specs': dict(x=T_1_2),
        },
        {
            'shape': [2],
            'fields':
            dict(a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
                 b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
            'field_specs':
            dict(a=R_1_N, b=T_2_3),
        },
    ])  # pyformat: disable
    def testToFromComponents(self, shape, fields, field_specs):
        """Round-trips a struct through `_to_components`/`_from_components`.

        Args:
          shape: Shape used for both the struct and the spec.
          fields: Field values used to build the struct.
          field_specs: Field specs used to build the spec.
        """
        original = StructuredTensor.from_fields(fields, shape)
        spec = StructuredTensorSpec(shape, field_specs)

        components = spec._to_components(original)
        self.assertLen(components, 3)
        # The first component holds the field values.
        self.assertAllTensorsEqual(components[0], fields)

        rebuilt = spec._from_components(components)
        self.assertAllEqual(original, rebuilt)

    def testToFromComponentsEmptyScalar(self):
        """Round-trips a field-less scalar struct through its components.

        The components of such a struct must be `({}, (), ())`.
        """
        original = StructuredTensor.from_fields(fields={}, shape=[])
        spec = original._type_spec
        components = spec._to_components(original)
        rebuilt = spec._from_components(components)

        self.assertAllEqual(original, rebuilt)
        expected_components = ({}, (), ())
        self.assertEqual(components, expected_components)

    def testToFromComponentsEmptyTensor(self):
        """Round-trips a field-less struct of shape [1, 2, 3].

        The components must be an empty field dict, nrows equal to 1, and two
        RowPartitions with row splits [0, 2] and [0, 3, 6].
        """
        original = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
        spec = original._type_spec
        components = spec._to_components(original)
        rebuilt = spec._from_components(components)
        self.assertAllEqual(original, rebuilt)

        self.assertLen(components, 3)
        fields, nrows, partitions = components
        self.assertEmpty(fields)
        self.assertAllEqual(nrows, 1)

        self.assertLen(partitions, 2)
        for index in (0, 1):
            self.assertIsInstance(partitions[index],
                                  row_partition.RowPartition)
        expected_splits = ([0, 2], [0, 3, 6])
        for index, splits in enumerate(expected_splits):
            self.assertAllEqual(partitions[index].row_splits(), splits)

    @parameterized.parameters([
        {
            'unbatched': StructuredTensorSpec([], {}),
            'batch_size': 5,
            'batched': StructuredTensorSpec([5], {}),
        },
        {
            'unbatched': StructuredTensorSpec([1, 2], {}),
            'batch_size': 5,
            'batched': StructuredTensorSpec([5, 1, 2], {}),
        },
        {
            'unbatched': StructuredTensorSpec([], {'a': T_3, 'b': R_1_N}),
            'batch_size': 2,
            'batched': StructuredTensorSpec([2], {'a': T_2_3, 'b': R_2_1_N}),
        },
    ])  # pyformat: disable
    def testBatchUnbatch(self, unbatched, batch_size, batched):
        """Checks that `_batch` and `_unbatch` map the specs onto each other.

        Args:
          unbatched: Spec before batching.
          batch_size: Batch size passed to `_batch`.
          batched: Spec expected after batching `unbatched`.
        """
        batched_actual = unbatched._batch(batch_size)
        self.assertEqual(batched_actual, batched)
        unbatched_actual = batched._unbatch()
        self.assertEqual(unbatched_actual, unbatched)

    @parameterized.parameters([
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
                StructuredTensor.from_fields({'a': 2, 'b': [7, 8]}),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b': [[5, 6], [7, 8]],
                }),
        },
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields(
                    shape=[3],
                    fields={
                        'a': [1, 2, 3],
                        'b': [[5, 6], [6, 7], [7, 8]],
                    }),
                StructuredTensor.from_fields(
                    shape=[3],
                    fields={
                        'a': [2, 3, 4],
                        'b': [[2, 2], [3, 3], [4, 4]],
                    }),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2, 3],
                fields={
                    'a': [[1, 2, 3], [2, 3, 4]],
                    'b': [[[5, 6], [6, 7], [7, 8]],
                          [[2, 2], [3, 3], [4, 4]]],
                }),
        },
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 1,
                        'b': StructuredTensor.from_fields({'x': [5]}),
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 2,
                        'b': StructuredTensor.from_fields({'x': [6]}),
                    }),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b': StructuredTensor.from_fields(
                        shape=[2], fields={'x': [[5], [6]]}),
                }),
        },
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d': ragged_factory_ops.constant_value(
                            [[1, 2], [3]]),
                        'Ragged2d': ragged_factory_ops.constant_value([1]),
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                        'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
                    }),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'Ragged3d': ragged_factory_ops.constant_value(
                        [[[1, 2], [3]], [[1]]]),
                    'Ragged2d': ragged_factory_ops.constant_value(
                        [[1], [2, 3]]),
                }),
            'use_only_batched_spec': True,
        },
    ])  # pyformat: disable
    def testBatchUnbatchValues(self,
                               unbatched,
                               batch_size,
                               batched,
                               use_only_batched_spec=False):
        """Batches and unbatches StructuredTensor values via tensor lists.

        Args:
          unbatched: Zero-argument callable returning the list of unbatched
            StructuredTensors.
          batch_size: Batch size passed to the unbatched spec's `_batch`.
          batched: Zero-argument callable returning the StructuredTensor
            expected from batching the unbatched values.
          use_only_batched_spec: If true, derive the unbatched spec by
            unbatching the spec of `batched` instead of taking it from the
            first unbatched value.
        """
        # Both arguments are factories because building them creates tensors;
        # they are only materialized once the test is running.
        batched = batched()
        unbatched = unbatched()

        # --- Batching ---
        if not use_only_batched_spec:
            element_spec = type_spec.type_spec_from_value(unbatched[0])
        else:
            element_spec = type_spec.type_spec_from_value(batched)._unbatch()
        per_value_tensors = [
            element_spec._to_tensor_list(value) for value in unbatched
        ]
        # Stack the i-th tensor of every value into a single batched tensor.
        stacked_tensors = [
            array_ops.stack(group) for group in zip(*per_value_tensors)
        ]
        rebuilt_batch = element_spec._batch(batch_size)._from_tensor_list(
            stacked_tensors)
        self.assertTrue(
            element_spec._batch(batch_size).is_compatible_with(rebuilt_batch))
        self.assertAllEqual(rebuilt_batch, batched)

        # --- Unbatching ---
        whole_spec = type_spec.type_spec_from_value(batched)
        whole_tensors = whole_spec._to_batched_tensor_list(batched)
        # Transpose "per tensor, per row" into "per row, per tensor".
        rows = zip(*map(array_ops.unstack, whole_tensors))
        rebuilt_values = [
            whole_spec._unbatch()._from_tensor_list(row) for row in rows
        ]
        self.assertLen(rebuilt_values, len(unbatched))
        for rebuilt in rebuilt_values:
            self.assertTrue(whole_spec._unbatch().is_compatible_with(rebuilt))
        for rebuilt, wanted in zip(rebuilt_values, unbatched):
            self.assertAllEqual(rebuilt, wanted)
コード例 #22
0
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):

    # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
    # assertAllEqual etc to work with StructuredTensors.
    def assertAllEqual(self, a, b, msg=None):
        """Asserts equality, with support for comparing two StructuredTensors.

        When neither argument is a StructuredTensor the check is delegated to
        the base class. Two StructuredTensors are considered equal when their
        shapes have the same repr, they have the same set of field names, and
        every field value compares equal under this same method (so nested
        StructuredTensors are compared recursively).

        Args:
          a: First value.
          b: Second value.
          msg: Optional message reported on failure.

        Raises:
          ValueError: If exactly one of `a` and `b` is a StructuredTensor.
        """
        a_is_struct = isinstance(a, structured_tensor.StructuredTensor)
        b_is_struct = isinstance(b, structured_tensor.StructuredTensor)
        if not (a_is_struct or b_is_struct):
            return super(StructuredTensorSpecTest,
                         self).assertAllEqual(a, b, msg)
        if not (a_is_struct and b_is_struct):
            # TODO(edloper) Add support for this once structured_factory_ops is added.
            raise ValueError('Not supported yet')

        # Forward `msg` so a caller-supplied message is not silently dropped
        # when the comparison takes the StructuredTensor path.
        self.assertEqual(repr(a.shape), repr(b.shape), msg)
        self.assertEqual(set(a.field_names()), set(b.field_names()), msg)
        for field in a.field_names():
            self.assertAllEqual(a.field_value(field), b.field_value(field),
                                msg)

    def assertAllTensorsEqual(self, x, y):
        """Asserts that two dicts have the same keys and pairwise-equal values.

        Args:
          x: Dict of values.
          y: Dict of values expected to match `x` key by key.
        """
        # unittest assertions are used instead of a bare `assert`: a bare
        # assert is stripped under `python -O` and reports no diagnostic.
        self.assertIsInstance(x, dict)
        self.assertIsInstance(y, dict)
        self.assertEqual(set(x), set(y))
        for key, value in x.items():
            self.assertAllEqual(value, y[key])

    def testConstruction(self):
        """Checks that the constructor stores the shape and field specs."""
        inner_fields = {'a': T_1_2_3_4}
        inner_spec = StructuredTensorSpec([1, 2, 3], inner_fields)
        self.assertEqual(inner_spec._shape, (1, 2, 3))
        self.assertEqual(inner_spec._field_specs, inner_fields)

        # The second spec uses the first one as the spec of its field 's'.
        outer_fields = {
            'a': T_1_2,
            'b': T_1_2_8,
            'c': R_1_N,
            'd': R_1_N_N,
            's': inner_spec,
        }
        outer_spec = StructuredTensorSpec([1, 2], outer_fields)
        self.assertEqual(outer_spec._shape, (1, 2))
        self.assertEqual(outer_spec._field_specs, outer_fields)

    @parameterized.parameters([
        (None, {}, r"StructuredTensor's shape must have known rank\."),
        ([], None, r'field_specs must be a dictionary\.'),
        ([], {1: tensor_spec.TensorSpec(None)},
         r'field_specs must be a dictionary with string keys\.'),
        ([], {'x': 0},
         r'field_specs must be a dictionary with TypeSpec values\.'),
    ])
    def testConstructionErrors(self, shape, field_specs, error):
        """Checks that invalid constructor arguments raise TypeError.

        Args:
          shape: Shape argument passed to the spec constructor.
          field_specs: Field-specs argument passed to the spec constructor.
          error: Regex that the TypeError message must match.
        """
        spec_cls = structured_tensor.StructuredTensorSpec
        with self.assertRaisesRegex(TypeError, error):
            spec_cls(shape, field_specs)

    def testValueType(self):
        """Checks that the spec reports StructuredTensor as its value type."""
        spec = StructuredTensorSpec([1, 2], {'a': T_1_2})
        reported_type = spec.value_type
        self.assertEqual(reported_type, StructuredTensor)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3], {}), (
            ('_fields', {}),
            ('_ragged_shape',
             dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
                 [1, 2, 3], num_row_partitions=0, dtype=dtypes.int32)),
        )),
        (StructuredTensorSpec([], {'a': T_1_2}), (
            ('_fields', {'a': T_1_2}),
            ('_ragged_shape',
             dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
                 [], num_row_partitions=0, dtype=dtypes.int64)),
        )),
        (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}), (
            ('_fields', {'a': T_1_2, 'b': R_1_N}),
            ('_ragged_shape',
             dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
                 [1, 2], num_row_partitions=1, dtype=dtypes.int64)),
        )),
    ])  # pyformat: disable
    def testSerialize(self, spec, expected):
        """Checks the output of `_serialize()` against an expected tuple.

        Args:
          spec: StructuredTensorSpec to serialize.
          expected: Expected serialization, as (name, value) pairs.
        """
        # Plain assertEqual is sufficient only because no case here has a
        # None dimension: a TensorShape with a None dimension never compares
        # equal to another TensorShape.
        actual = spec._serialize()
        self.assertEqual(actual, expected)

    @parameterized.parameters([
        (StructuredTensorSpec([1, 2, 3], {}), (
            dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
                [1, 2, 3], num_row_partitions=0, dtype=dtypes.int32),
        )),
        (StructuredTensorSpec([], {'a': T_1_2}), (
            tensor_spec.TensorSpec(
                shape=(1, 2), dtype=dtypes.float32, name=None),
            dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
                [], num_row_partitions=0, dtype=dtypes.int64),
        )),
        (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}), (
            T_1_2,
            R_1_N,
            dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
                [1, 2], num_row_partitions=1, dtype=dtypes.int64),
        )),
    ])  # pyformat: disable
    def testComponentSpecs(self, spec, expected):
        """Checks `_component_specs` against an expected tuple of specs.

        Args:
          spec: StructuredTensorSpec under test.
          expected: Expected value of `spec._component_specs`.
        """
        component_specs = spec._component_specs
        self.assertEqual(component_specs, expected)

    @parameterized.parameters([
        {
            'shape': [],
            'fields': {'x': [[1.0, 2.0]]},
            'field_specs': {'x': T_1_2},
        },
        {
            'shape': [2],
            'fields': {
                'a': ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
                'b': [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            },
            'field_specs': {'a': R_2_N, 'b': T_2_3},
        },
    ])  # pyformat: disable
    def testToFromComponents(self, shape, fields, field_specs):
        """Round-trips a StructuredTensor through its spec's components.

        Args:
          shape: Shape passed to both the tensor and the spec.
          fields: Field values used to build the StructuredTensor.
          field_specs: Field specs used to build the StructuredTensorSpec.
        """
        original = StructuredTensor.from_fields(fields, shape)
        spec_under_test = StructuredTensorSpec(shape, field_specs)

        # Decompose, then rebuild; the rebuilt value must equal the original.
        components = spec_under_test._to_components(original)
        round_tripped = spec_under_test._from_components(components)
        self.assertAllEqual(original, round_tripped)

    def testToFromComponentsEmptyScalar(self):
        """Round-trips a field-less scalar StructuredTensor via its spec."""
        empty_scalar = StructuredTensor.from_fields(fields={}, shape=[])
        own_spec = empty_scalar._type_spec

        # Decompose with the value's own spec, then rebuild and compare.
        parts = own_spec._to_components(empty_scalar)
        rebuilt = own_spec._from_components(parts)
        self.assertAllEqual(empty_scalar, rebuilt)

    def testToFromComponentsEmptyTensor(self):
        """Round-trips a field-less [1, 2, 3] StructuredTensor via its spec."""
        empty_tensor = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
        own_spec = empty_tensor._type_spec

        # Decompose with the value's own spec, then rebuild and compare.
        parts = own_spec._to_components(empty_tensor)
        rebuilt = own_spec._from_components(parts)
        self.assertAllEqual(empty_tensor, rebuilt)

    @parameterized.parameters([
        {
            'unbatched': StructuredTensorSpec([], {}),
            'batch_size': 5,
            'batched': StructuredTensorSpec([5], {}),
        },
        {
            'unbatched': StructuredTensorSpec([1, 2], {}),
            'batch_size': 5,
            'batched': StructuredTensorSpec([5, 1, 2], {}),
        },
        {
            'unbatched': StructuredTensorSpec([], {'a': T_3, 'b': R_1_N}),
            'batch_size': 2,
            'batched': StructuredTensorSpec([2], {'a': T_2_3, 'b': R_2_1_N}),
        },
    ])  # pyformat: disable
    def testBatchUnbatch(self, unbatched, batch_size, batched):
        """Checks that `_batch` and `_unbatch` map the specs onto each other.

        Args:
          unbatched: Spec before batching.
          batch_size: Batch size passed to `_batch`.
          batched: Spec expected after batching `unbatched`.
        """
        spec_after_batch = unbatched._batch(batch_size)
        self.assertEqual(spec_after_batch, batched)
        spec_after_unbatch = batched._unbatch()
        self.assertEqual(spec_after_unbatch, unbatched)

    @parameterized.parameters([
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
                StructuredTensor.from_fields({'a': 2, 'b': [7, 8]}),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b': [[5, 6], [7, 8]],
                }),
        },
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields(
                    shape=[3],
                    fields={
                        'a': [1, 2, 3],
                        'b': [[5, 6], [6, 7], [7, 8]],
                    }),
                StructuredTensor.from_fields(
                    shape=[3],
                    fields={
                        'a': [2, 3, 4],
                        'b': [[2, 2], [3, 3], [4, 4]],
                    }),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2, 3],
                fields={
                    'a': [[1, 2, 3], [2, 3, 4]],
                    'b': [[[5, 6], [6, 7], [7, 8]],
                          [[2, 2], [3, 3], [4, 4]]],
                }),
        },
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 1,
                        'b': StructuredTensor.from_fields({'x': [5]}),
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'a': 2,
                        'b': StructuredTensor.from_fields({'x': [6]}),
                    }),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'a': [1, 2],
                    'b': StructuredTensor.from_fields(
                        shape=[2], fields={'x': [[5], [6]]}),
                }),
        },
        {
            'unbatched': lambda: [
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d': ragged_factory_ops.constant_value(
                            [[1, 2], [3]]),
                        'Ragged2d': ragged_factory_ops.constant_value([1]),
                    }),
                StructuredTensor.from_fields(
                    shape=[],
                    fields={
                        'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                        'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
                    }),
            ],
            'batch_size': 2,
            'batched': lambda: StructuredTensor.from_fields(
                shape=[2],
                fields={
                    'Ragged3d': ragged_factory_ops.constant_value(
                        [[[1, 2], [3]], [[1]]]),
                    'Ragged2d': ragged_factory_ops.constant_value(
                        [[1], [2, 3]]),
                }),
            'use_only_batched_spec': True,
        },
    ])  # pyformat: disable
    def testBatchUnbatchValues(self,
                               unbatched,
                               batch_size,
                               batched,
                               use_only_batched_spec=False):
        """Checks dataset batching/unbatching of StructuredTensor values.

        Two checks are performed. Values are only compared when executing
        eagerly; otherwise the datasets are merely constructed.

        1. A dataset holding `batched` is unbatched and re-batched, and the
           first resulting element is compared with `batched`.
        2. Unless `use_only_batched_spec` is set, a generator dataset over
           the unbatched values is batched, and its first element is
           compared with `batched`.

        Args:
          unbatched: Zero-argument callable returning the list of unbatched
            StructuredTensors.
          batch_size: Batch size used when batching the datasets.
          batched: Zero-argument callable returning the expected batched
            StructuredTensor.
          use_only_batched_spec: If true, skip the generator-based check,
            which takes its signature from the first unbatched value.
        """
        batched = batched()  # Deferred init because it creates tensors.
        unbatched = unbatched()  # Deferred init because it creates tensors.

        # Check 1: unbatch the expected value, then batch it again.
        unbatched_ds = dataset_ops.Dataset.from_tensors(batched).unbatch()
        if context.executing_eagerly():
            # Use the `batch_size` parameter rather than a hard-coded literal
            # so the test stays correct if a case with another size is added.
            rebatched = list(unbatched_ds.batch(batch_size))
            self.assertAllEqual(rebatched[0], batched)

        if use_only_batched_spec:
            return

        # Check 2: feed the unbatched values through a generator dataset.
        def unbatch_gen():
            yield from unbatched

        unbatched_spec = type_spec.type_spec_from_value(unbatched[0])
        generated_ds = dataset_ops.Dataset.from_generator(
            unbatch_gen, output_signature=unbatched_spec)
        batched_ds = generated_ds.batch(batch_size)
        if context.executing_eagerly():
            self.assertAllEqual(list(batched_ds)[0], batched)

    def _lambda_for_fields(self):
        """Returns a zero-argument callable that builds a dict of fields.

        Nothing is constructed until the returned callable is invoked. The
        resulting dict maps 'a' and 'b' to NumPy arrays of ones, 'c' through
        'f' to values from `ragged_factory_ops.constant`, and 'g' and 'h' to
        StructuredTensors built with `from_pyval`.
        """

        def build_fields():
            return {
                'a': np.ones([1, 2, 3, 1]),
                'b': np.ones([1, 2, 3, 1, 5]),
                'c': ragged_factory_ops.constant(
                    np.zeros([1, 2, 3, 1], dtype=np.uint8),
                    dtype=dtypes.uint8),
                'd': ragged_factory_ops.constant(
                    np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1),
                'e': ragged_factory_ops.constant(
                    np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
                'f': ragged_factory_ops.constant(
                    np.zeros([1, 2, 3, 1, 3]), dtype=dtypes.float32),
                # Nested lists of dicts indexed as [0][j][k][0].
                'g': StructuredTensor.from_pyval([[
                    [  # pylint: disable=g-complex-comprehension
                        [{'x': j, 'y': k}] for k in range(3)
                    ] for j in range(2)
                ]]),
                # As 'g', but the innermost list has `j` dicts, so it is
                # empty when j == 0.
                'h': StructuredTensor.from_pyval([[
                    [  # pylint: disable=g-complex-comprehension
                        [[{'x': j, 'y': k, 'z': z} for z in range(j)]]
                        for k in range(3)
                    ] for j in range(2)
                ]]),
            }

        return build_fields