def testFlatTensorSpecs(self):
        # Verifies the exact list returned by `_flat_tensor_specs` for the
        # spec of a rank-4 StructuredTensor built from the fields supplied by
        # `self._lambda_for_fields()`.
        #
        # Note that the batchable tensor list encoding for a StructuredTensor
        # contains a separate tensor for each leaf field.
        # In this example, _flat_tensor_specs in class StructuredTensorSpec is
        # called three times and it returns results with length 2, 3 and 11
        # for "g", "h" and `struct` respectively.
        field_source = self._lambda_for_fields()
        # Deferred construction: fields may include tensors, so the helper
        # may hand back a callable that has to be invoked first.
        field_values = (
            field_source() if callable(field_source) else field_source)

        structured = StructuredTensor.from_fields_and_rank(field_values, 4)
        structured_spec = type_spec.type_spec_from_value(structured)
        actual_specs = structured_spec._flat_tensor_specs

        def _variant_spec():
            # Expected spec for every flat component that is not a plain
            # float64 tensor in this example.
            return tensor_spec.TensorSpec(
                shape=None, dtype=dtypes.variant, name=None)

        # a, b
        expected_specs = [
            tensor_spec.TensorSpec(
                shape=dense_shape, dtype=dtypes.float64, name=None)
            for dense_shape in ((1, 2, 3, 1), (1, 2, 3, 1, 5))
        ]
        # c, d, e, f
        expected_specs.extend(_variant_spec() for _ in range(4))
        # g
        expected_specs.extend(_variant_spec() for _ in range(2))
        # h
        expected_specs.extend(_variant_spec() for _ in range(3))

        self.assertEqual(actual_specs, expected_specs)
    def testFulTypesForFlatTensors(self):
        # Verifies that `fulltypes_for_flat_tensors` returns the expected
        # FullTypeDef for each flat tensor of the spec, and that there is
        # exactly one expected entry per entry of `_flat_tensor_specs`.
        #
        # Note that the batchable tensor list encoding for a StructuredTensor
        # contains a separate tensor for each leaf field.
        # In this example, _flat_tensor_specs in class StructuredTensorSpec is
        # called three times and it returns results with length 2, 3 and 11
        # for "g", "h" and `struct` respectively.
        field_source = self._lambda_for_fields()
        # Deferred construction: fields may include tensors, so the helper
        # may hand back a callable that has to be invoked first.
        field_values = (
            field_source() if callable(field_source) else field_source)

        structured = StructuredTensor.from_fields_and_rank(field_values, 4)
        structured_spec = type_spec.type_spec_from_value(structured)
        flat_specs = structured_spec._flat_tensor_specs
        actual_fulltypes = fulltypes_for_flat_tensors(structured_spec)

        def _ragged_of(element_type_id):
            # TFT_RAGGED full type with a single argument of the given
            # element type id.
            element = full_type_pb2.FullTypeDef(type_id=element_type_id)
            return full_type_pb2.FullTypeDef(
                type_id=full_type_pb2.TFT_RAGGED, args=[element])

        # Element type ids of the ragged entries, in flat-tensor order.
        ragged_element_ids = (
            # c, d, e, f
            [full_type_pb2.TFT_UINT8] + [full_type_pb2.TFT_FLOAT] * 3
            # g
            + [full_type_pb2.TFT_INT32] * 2
            # h
            + [full_type_pb2.TFT_INT32] * 3
        )

        # a, b
        expected_fulltypes = [
            full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
            for _ in range(2)
        ]
        expected_fulltypes.extend(
            _ragged_of(type_id) for type_id in ragged_element_ids)

        self.assertEqual(len(expected_fulltypes), len(flat_specs))
        self.assertEqual(actual_fulltypes, expected_fulltypes)