def testSerializedContainingRaggedFeatureWithNoPartitions(self):
    """Parses `RaggedFeature` specs that are declared without partitions.

    Every spec reads one flat list of values per serialized example, so each
    expected result has one ragged row per example. Where an example omits a
    key, or carries an empty feature for it, the expected row is empty.
    """
    feature_dicts = [
        {
            "rt_c": float_feature([3, 4, 5, 6, 7, 8]),
            "rt_f_values": float_feature([0, 1, 2, 3, 4]),
        },
        {
            "rt_c": float_feature([]),  # empty float list
        },
        {
            "rt_d": feature(),  # feature with nothing in it
        },
        {
            "rt_c": float_feature([1, 2, -1]),
            "rt_d": bytes_feature([b"hi"]),
            "rt_f_values": float_feature([0, 1, 2]),
        },
    ]
    serialized = [
        example(features=features(feature_dict)).SerializeToString()
        for feature_dict in feature_dicts
    ]

    # "rt_f" takes its values from the "rt_f_values" key; "rt_d" is the only
    # spec that asks for int64 row splits.
    feature_spec = {
        "rt_c":
            parsing_ops.RaggedFeature(dtypes.float32),
        "rt_d":
            parsing_ops.RaggedFeature(
                dtypes.string, row_splits_dtype=dtypes.int64),
        "rt_f":
            parsing_ops.RaggedFeature(
                dtypes.float32, value_key="rt_f_values"),
    }

    # Rows are listed in the same order as `feature_dicts`. The expected
    # row-splits dtype is int32 wherever the spec does not set one.
    expected = {
        "rt_c":
            ragged_factory_ops.constant_value(
                [[3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [], [], [1.0, 2.0, -1.0]],
                row_splits_dtype=dtypes.int32),
        "rt_d":
            ragged_factory_ops.constant_value(
                [[], [], [], [b"hi"]], row_splits_dtype=dtypes.int64),
        "rt_f":
            ragged_factory_ops.constant_value(
                [[0.0, 1.0, 2.0, 3.0, 4.0], [], [], [0.0, 1.0, 2.0]],
                row_splits_dtype=dtypes.int32),
    }

    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)
  def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
    """Parses a `RaggedFeature` whose spec stacks three partitions.

    The spec combines a uniform row length of 2, a row-lengths key and a
    row-splits key over the flat "rt_values" list, and requests int64 row
    splits. The expected result nests each example's values accordingly.
    """
    # rt shape: [(batch), 2, None, None]
    # "lengths_axis3" is stored in each example but is not referenced by the
    # spec below, which uses "splits_axis3" for that axis instead.
    feature_dicts = [
        {
            # rt = [[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]]
            "rt_values": float_feature([1, 2, 3, 4, 5, 6, 7]),
            "lengths_axis2": int64_feature([1, 2, 0, 1]),
            "lengths_axis3": int64_feature([1, 2, 1, 3]),
            "splits_axis3": int64_feature([0, 1, 3, 4, 7]),
        },
        {
            # rt = [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]
            "rt_values": float_feature([1, 2, 3, 4, 5, 6, 7, 8]),
            "lengths_axis2": int64_feature([2, 3]),
            "lengths_axis3": int64_feature([3, 1, 1, 1, 2]),
            "splits_axis3": int64_feature([0, 3, 4, 5, 6, 8]),
        },
    ]
    serialized = [
        example(features=features(feature_dict)).SerializeToString()
        for feature_dict in feature_dicts
    ]

    # Partitions are listed from the outermost axis to the innermost one.
    partitions = [
        parsing_ops.RaggedFeature.UniformRowLength(2),
        parsing_ops.RaggedFeature.RowLengths("lengths_axis2"),
        parsing_ops.RaggedFeature.RowSplits("splits_axis3"),
    ]
    feature_spec = {
        "rt1":
            parsing_ops.RaggedFeature(
                dtype=dtypes.float32,
                value_key="rt_values",
                partitions=partitions,
                row_splits_dtype=dtypes.int64),
    }

    # Outermost list is the batch: one entry per serialized example.
    first_example_rt = [[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]]
    second_example_rt = [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]
    expected = {
        "rt1":
            ragged_factory_ops.constant(
                [first_example_rt, second_example_rt],
                dtype=dtypes.float32,
                row_splits_dtype=dtypes.int64),
    }

    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)
  def testSerializedContainingRaggedFeatureWithOnePartition(self):
    """Parses `RaggedFeature` specs built from each kind of row partition.

    Every example stores the same row partition under five keys (splits,
    lengths, starts, limits and value row ids). Specs "rt1" to "rt5" each use
    one of those keys and are all expected to produce the same ragged value.
    "uniform1" and "uniform2" additionally cover a uniform row length of 2,
    alone and stacked on top of the row-splits key.
    """
    feature_dicts = [
        {
            # rt = [[3], [4, 5, 6]]
            "rt_values": float_feature([3, 4, 5, 6]),
            "rt_splits": int64_feature([0, 1, 4]),
            "rt_lengths": int64_feature([1, 3]),
            "rt_starts": int64_feature([0, 1]),
            "rt_limits": int64_feature([1, 4]),
            "rt_rowids": int64_feature([0, 1, 1, 1]),
        },
        {
            # rt = []
            "rt_values": float_feature([]),
            "rt_splits": int64_feature([0]),
            "rt_lengths": int64_feature([]),
            "rt_starts": int64_feature([]),
            "rt_limits": int64_feature([]),
            "rt_rowids": int64_feature([]),
        },
        {
            # rt = []
            "rt_values": feature(),  # feature with nothing in it
            "rt_splits": int64_feature([0]),
            "rt_lengths": feature(),
            "rt_starts": feature(),
            "rt_limits": feature(),
            "rt_rowids": feature(),
        },
        {
            # rt = [[1.0, 2.0, -1.0], [], [8.0, 9.0], [5.0]]
            "rt_values": float_feature([1, 2, -1, 8, 9, 5]),
            "rt_splits": int64_feature([0, 3, 3, 5, 6]),
            "rt_lengths": int64_feature([3, 0, 2, 1]),
            "rt_starts": int64_feature([0, 3, 3, 5]),
            "rt_limits": int64_feature([3, 3, 5, 6]),
            "rt_rowids": int64_feature([0, 0, 0, 2, 2, 3]),
        },
    ]
    serialized = [
        example(features=features(feature_dict)).SerializeToString()
        for feature_dict in feature_dicts
    ]

    def spec_for(*partitions):
      # All specs share the value key and dtype; only the partitions differ.
      return parsing_ops.RaggedFeature(
          value_key="rt_values",
          partitions=list(partitions),
          dtype=dtypes.float32)

    ragged = parsing_ops.RaggedFeature
    feature_spec = {
        "rt1": spec_for(ragged.RowSplits("rt_splits")),
        "rt2": spec_for(ragged.RowLengths("rt_lengths")),
        "rt3": spec_for(ragged.RowStarts("rt_starts")),
        "rt4": spec_for(ragged.RowLimits("rt_limits")),
        "rt5": spec_for(ragged.ValueRowIds("rt_rowids")),
        "uniform1": spec_for(ragged.UniformRowLength(2)),
        "uniform2":
            spec_for(
                ragged.UniformRowLength(2), ragged.RowSplits("rt_splits")),
    }

    # Shared expectation for the five single-partition specs.
    expected_rt = ragged_factory_ops.constant(
        [[[3], [4, 5, 6]], [], [], [[1, 2, -1], [], [8, 9], [5]]],
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int32)

    # Values grouped into pairs; only the batch axis is ragged.
    expected_uniform1 = ragged_factory_ops.constant(
        [[[3, 4], [5, 6]], [], [], [[1, 2], [-1, 8], [9, 5]]],
        ragged_rank=1,
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int32)

    # The rows of `expected_rt`, grouped into pairs.
    expected_uniform2 = ragged_factory_ops.constant(
        [[[[3], [4, 5, 6]]], [], [], [[[1, 2, -1], []], [[8, 9], [5]]]],
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int32)

    expected = {
        name: expected_rt for name in ("rt1", "rt2", "rt3", "rt4", "rt5")
    }
    expected["uniform1"] = expected_uniform1
    expected["uniform2"] = expected_uniform2

    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)