def testFromRowLimits(self):
    """from_row_limits keeps the limits and derives nrows and row_splits."""
    limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

    partition = RowPartition.from_row_limits(limits, validate=False)
    self.assertEqual(partition.dtype, dtypes.int64)

    # Request the encodings in a fixed order (limits, splits, nrows) before
    # comparing anything.
    actual_limits = partition.row_limits()
    actual_splits = partition.row_splits()
    actual_nrows = partition.nrows()

    self.assertAllEqual(actual_nrows, 5)
    self.assertAllEqual(actual_limits, limits)
    self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7])
  def testClassDocStringExamples(self):
    """Runs the examples shown in the RowPartition class docstring."""
    expected = [0, 4, 4, 7, 8, 8]

    # Docstring section "Component Tensors".
    partition = RowPartition.from_row_splits(row_splits=expected)
    self.assertAllEqual(partition.row_splits(), expected)
    del partition

    # Docstring section "Alternative Row-Partitioning Schemes": each factory
    # below encodes the same partition, so all must report `expected`.
    equivalents = [
        RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]),
        RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0]),
        RowPartition.from_value_rowids(
            value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5),
        RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8),
        RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8]),
    ]
    for candidate in equivalents:
      self.assertAllEqual(candidate.row_splits(), expected)
    del equivalents

    # Docstring section "Multiple Ragged Dimensions": construction only.
    inner = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    outer = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
    del inner, outer
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for RowPartition factories, accessors, validation and merging.

  Most tests build a RowPartition through one of its factory methods and then
  check the encodings (row_splits, row_lengths, value_rowids, nrows, ...) that
  it reports.
  """
  #=============================================================================
  # RowPartition class docstring examples
  #=============================================================================

  def testClassDocStringExamples(self):
    """Runs the examples shown in the RowPartition class docstring."""
    # From section: "Component Tensors"
    rp = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
    del rp

    # From section: "Alternative Row-Partitioning Schemes"
    rt1 = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    rt2 = RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0])
    rt3 = RowPartition.from_value_rowids(
        value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
    rt4 = RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8)
    rt5 = RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8])
    for rp in (rt1, rt2, rt3, rt4, rt5):
      self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
    del rt1, rt2, rt3, rt4, rt5

    # From section: "Multiple Ragged Dimensions"
    inner_rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    outer_rt = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
    del inner_rt, outer_rt

  #=============================================================================
  # RowPartition Constructor (private)
  #=============================================================================

  def testRaggedTensorConstruction(self):
    """The private constructor succeeds when given the factory key."""
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    rp = RowPartition(
        row_splits=row_splits,
        internal=row_partition._row_partition_factory_key)
    self.assertAllEqual(rp.row_splits(), [0, 2, 2, 5, 6, 7])

  def testRaggedTensorConstructionErrors(self):
    """The private constructor rejects bad keys, types and ranks."""
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    # Calling the constructor without the factory key is rejected.
    with self.assertRaisesRegex(ValueError,
                                'RaggedTensor constructor is private'):
      RowPartition(row_splits=row_splits)

    with self.assertRaisesRegex(TypeError,
                                'Row-partitioning argument must be a Tensor'):
      RowPartition(
          row_splits=[0, 2, 2, 5, 6, 7],
          internal=row_partition._row_partition_factory_key)

    with self.assertRaisesRegex(ValueError, r'Shape \(6, 1\) must have rank 1'):
      RowPartition(
          row_splits=array_ops.expand_dims(row_splits, 1),
          internal=row_partition._row_partition_factory_key)

    with self.assertRaisesRegex(TypeError,
                                'Cached value must be a Tensor or None.'):
      RowPartition(
          row_splits=row_splits,
          row_lengths=[2, 3, 4],
          internal=row_partition._row_partition_factory_key)

  #=============================================================================
  # RowPartition Factory Ops
  #=============================================================================

  def testFromValueRowIdsWithDerivedNRows(self):
    # nrows is known at graph creation time.
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    # TODO(martinz): add nrows
    rp = RowPartition.from_value_rowids(value_rowids, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_splits = rp.row_splits()
    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertAllEqual(rp_value_rowids, value_rowids)
    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromValueRowIdsWithDerivedNRowsDynamic(self):
    # nrows is not known at graph creation time.
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None)

    rp = RowPartition.from_value_rowids(value_rowids, validate=False)

    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertAllEqual(rp_value_rowids, value_rowids)
    self.assertAllEqual(rp_nrows, 5)

  def testFromValueRowIdsWithExplicitNRows(self):
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(7, dtypes.int64)

    rp = RowPartition.from_value_rowids(value_rowids, nrows, validate=False)

    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()
    rp_row_splits = rp.row_splits()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertIs(rp_nrows, nrows)  # nrows
    # Trailing empty rows (nrows=7 > max(value_rowids) + 1) repeat the last
    # split.
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7, 7, 7])

  def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(5, dtypes.int64)

    rp = RowPartition.from_value_rowids(value_rowids, nrows, validate=False)

    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()
    rp_row_splits = rp.row_splits()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertIs(rp_nrows, nrows)  # nrows
    self.assertAllEqual(rp_value_rowids, value_rowids)
    self.assertAllEqual(rp_nrows, nrows)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromValueRowIdsWithEmptyValues(self):
    rp = RowPartition.from_value_rowids([])
    rp_nrows = rp.nrows()
    self.assertEqual(rp.dtype, dtypes.int64)
    self.assertEqual(rp.value_rowids().shape.as_list(), [0])
    self.assertAllEqual(rp_nrows, 0)

  def testFromRowSplits(self):
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    rp = RowPartition.from_row_splits(row_splits, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_splits = rp.row_splits()
    rp_nrows = rp.nrows()

    self.assertIs(rp_row_splits, row_splits)
    self.assertAllEqual(rp_nrows, 5)

  def testFromRowSplitsWithDifferentSplitTypes(self):
    """Python lists and int64 inputs give int64 splits; int32 stays int32."""
    splits1 = [0, 2, 2, 5, 6, 7]
    splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
    splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
    splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
    rt1 = RowPartition.from_row_splits(splits1)
    rt2 = RowPartition.from_row_splits(splits2)
    rt3 = RowPartition.from_row_splits(splits3)
    rt4 = RowPartition.from_row_splits(splits4)
    rt5 = RowPartition.from_row_splits(splits5)
    self.assertEqual(rt1.row_splits().dtype, dtypes.int64)
    self.assertEqual(rt2.row_splits().dtype, dtypes.int64)
    self.assertEqual(rt3.row_splits().dtype, dtypes.int32)
    self.assertEqual(rt4.row_splits().dtype, dtypes.int64)
    self.assertEqual(rt5.row_splits().dtype, dtypes.int32)

  def testFromRowSplitsWithEmptySplits(self):
    err_msg = 'row_splits tensor may not be empty'
    with self.assertRaisesRegex(ValueError, err_msg):
      RowPartition.from_row_splits([])

  def testFromRowStarts(self):
    nvals = constant_op.constant(7)
    row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)

    rp = RowPartition.from_row_starts(row_starts, nvals, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_starts = rp.row_starts()
    rp_row_splits = rp.row_splits()
    rp_nrows = rp.nrows()

    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_starts, row_starts)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromRowLimits(self):
    row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

    rp = RowPartition.from_row_limits(row_limits, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_limits = rp.row_limits()
    rp_row_splits = rp.row_splits()
    rp_nrows = rp.nrows()

    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_limits, row_limits)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromRowLengths(self):
    row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)

    rp = RowPartition.from_row_lengths(row_lengths, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_lengths = rp.row_lengths()
    rp_nrows = rp.nrows()

    self.assertIs(rp_row_lengths, row_lengths)  # row_lengths
    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_lengths, row_lengths)

  def testFromUniformRowLength(self):
    nvals = 16
    a1 = RowPartition.from_uniform_row_length(
        nvals=nvals, uniform_row_length=2)
    self.assertAllEqual(a1.uniform_row_length(), 2)
    self.assertAllEqual(a1.nrows(), 8)

  def testFromUniformRowLengthWithEmptyValues(self):
    a = RowPartition.from_uniform_row_length(
        nvals=0, uniform_row_length=0, nrows=10)
    self.assertEqual(self.evaluate(a.nvals()), 0)
    self.assertEqual(self.evaluate(a.nrows()), 10)

  def testFromUniformRowLengthWithPlaceholders1(self):
    nvals = array_ops.placeholder_with_default(
        constant_op.constant(6, dtype=dtypes.int64), None)
    rt1 = RowPartition.from_uniform_row_length(
        nvals=nvals, uniform_row_length=3)
    const_nvals1 = self.evaluate(rt1.nvals())
    self.assertEqual(const_nvals1, 6)

  def testFromUniformRowLengthWithPlaceholders2(self):
    nvals = array_ops.placeholder_with_default(6, None)
    ph_rowlen = array_ops.placeholder_with_default(3, None)
    rt2 = RowPartition.from_uniform_row_length(
        nvals=nvals, uniform_row_length=ph_rowlen)
    const_nvals2 = self.evaluate(rt2.nvals())
    self.assertEqual(const_nvals2, 6)

  def testFromValueRowIdsWithBadNRows(self):
    """from_value_rowids rejects negative, too-small, or non-scalar nrows."""
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(5, dtypes.int64)

    with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
      RowPartition.from_value_rowids(
          value_rowids=array_ops.placeholder_with_default(value_rowids, None),
          nrows=-2)

    with self.assertRaisesRegex(
        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
        r'value_rowids\[-1\]=4'):
      RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=2)

    with self.assertRaisesRegex(
        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
        r'value_rowids\[-1\]=4'):
      RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=4)

    with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
      RowPartition.from_value_rowids(
          value_rowids=array_ops.expand_dims(value_rowids, 1), nrows=nrows)

    with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
      RowPartition.from_value_rowids(
          value_rowids=value_rowids, nrows=array_ops.expand_dims(nrows, 0))

  #=============================================================================
  # RowPartition.__str__
  #=============================================================================
  def testRowPartitionStr(self):
    row_splits = [0, 2, 5, 6, 6, 7]
    rp = RowPartition.from_row_splits(row_splits, validate=False)
    # A Python list of ints yields int64 splits (see
    # testFromRowSplitsWithDifferentSplitTypes); use it in both branches.
    splits_type = 'int64'
    if context.executing_eagerly():
      expected_repr = ('tf.RowPartition(row_splits=tf.Tensor([0 2 5 6 6 7], '
                       'shape=(6,), dtype={}))').format(splits_type)
    else:
      expected_repr = ('tf.RowPartition(row_splits='
                       'Tensor("RowPartitionFromRowSplits/row_splits:0", '
                       'shape=(6,), dtype={}))').format(splits_type)
    self.assertEqual(repr(rp), expected_repr)
    self.assertEqual(str(rp), expected_repr)

  @parameterized.parameters([
      # from_value_rowids
      {
          'descr': 'bad rank for value_rowids',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [[1, 2], [3, 4]],
          'nrows': 10
      },
      {
          'descr': 'bad rank for nrows',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [1, 2, 3, 4],
          'nrows': [10]
      },
      {
          'descr': 'negative value_rowid',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [-5, 2, 3, 4],
          'nrows': 10
      },
      {
          'descr': 'non-monotonic-increasing value_rowid',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [4, 3, 2, 1],
          'nrows': 10
      },
      {
          'descr': 'value_rowid > nrows',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [1, 2, 3, 4],
          'nrows': 2
      },

      # from_row_splits
      {
          'descr': 'bad rank for row_splits',
          'factory': RowPartition.from_row_splits,
          'row_splits': [[1, 2], [3, 4]]
      },
      {
          'descr': 'row_splits[0] != 0',
          'factory': RowPartition.from_row_splits,
          'row_splits': [2, 3, 4]
      },
      {
          'descr': 'non-monotonic-increasing row_splits',
          'factory': RowPartition.from_row_splits,
          'row_splits': [0, 3, 2, 4]
      },

      # from_row_lengths
      {
          'descr': 'bad rank for row_lengths',
          'factory': RowPartition.from_row_lengths,
          'row_lengths': [[1, 2], [1, 0]]
      },
      {
          'descr': 'negative row_lengths',
          'factory': RowPartition.from_row_lengths,
          'row_lengths': [3, -1, 2]
      },

      # from_row_starts
      {
          'descr': 'bad rank for row_starts',
          'factory': RowPartition.from_row_starts,
          'nvals': 2,
          'row_starts': [[1, 2], [3, 4]]
      },
      {
          'descr': 'row_starts[0] != 0',
          'factory': RowPartition.from_row_starts,
          'nvals': 5,
          'row_starts': [2, 3, 4]
      },
      {
          'descr': 'non-monotonic-increasing row_starts',
          'factory': RowPartition.from_row_starts,
          'nvals': 4,
          'row_starts': [0, 3, 2, 4]
      },
      {
          # The last start (5) exceeds nvals (4).
          'descr': 'row_starts[-1] > nvals',
          'factory': RowPartition.from_row_starts,
          'nvals': 4,
          'row_starts': [0, 2, 3, 5]
      },

      # from_row_limits
      {
          'descr': 'bad rank for row_limits',
          'factory': RowPartition.from_row_limits,
          'row_limits': [[1, 2], [3, 4]]
      },
      {
          'descr': 'row_limits[0] < 0',
          'factory': RowPartition.from_row_limits,
          'row_limits': [-1, 3, 4]
      },
      {
          'descr': 'non-monotonic-increasing row_limits',
          'factory': RowPartition.from_row_limits,
          'row_limits': [0, 3, 2, 4]
      },

      # from_uniform_row_length
      {
          'descr': 'rowlen * nrows != nvals (1)',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 5,
          'uniform_row_length': 3
      },
      {
          'descr': 'rowlen * nrows != nvals (2)',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 5,
          'uniform_row_length': 6
      },
      {
          'descr': 'rowlen * nrows != nvals (3)',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 6,
          'uniform_row_length': 3,
          'nrows': 3
      },
      {
          'descr': 'rowlen must be a scalar',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 4,
          'uniform_row_length': [2]
      },
      {
          'descr': 'rowlen must be nonnegative',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 4,
          'uniform_row_length': -1
      },
  ])
  def testFactoryValidation(self, descr, factory, **kwargs):
    """Checks that a factory rejects invalid arguments.

    Args:
      descr: Human-readable description of the invalid case (label only; it
        is not used by the assertions).
      factory: The RowPartition factory method under test.
      **kwargs: Arguments forwarded to `factory`.
    """
    # When input tensors have shape information, some of these errors will be
    # detected statically.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      partition = factory(**kwargs)
      self.evaluate(partition.row_splits())

    # Remove shape information (by wrapping tensors in placeholders), and check
    # that we detect the errors when the graph is run.
    if not context.executing_eagerly():

      def wrap_arg(v):
        return array_ops.placeholder_with_default(
            constant_op.constant(v, dtype=dtypes.int64),
            tensor_shape.TensorShape(None))

      kwargs = {k: wrap_arg(v) for k, v in kwargs.items()}

      with self.assertRaises(errors.InvalidArgumentError):
        partition = factory(**kwargs)
        self.evaluate(partition.row_splits())

  @parameterized.named_parameters([
      ('FromRowSplits', lambda: RowPartition.from_row_splits([0, 2, 8]),
       ['row_splits']),
      ('FromRowLengths', lambda: RowPartition.from_row_lengths([3, 0, 8]),
       ['row_splits', 'row_lengths']),
      ('FromValueRowIds',
       lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]),
       ['row_splits', 'value_rowids', 'row_lengths', 'nrows']),
      ('FromRowStarts',
       lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10),
       ['row_splits']),
      ('FromRowLimits', lambda: RowPartition.from_row_limits([3, 7, 10]),
       ['row_splits']),
  ])
  def testPrecomputedSplits(self, rp_factory, expected_encodings):
    """Checks which encodings each factory reports as precomputed.

    Args:
      rp_factory: Zero-argument callable returning a RowPartition.
      expected_encodings: Names of the encodings expected to be precomputed.
    """
    rp = rp_factory()
    self.assertEqual(rp.has_precomputed_row_splits(),
                     'row_splits' in expected_encodings)
    self.assertEqual(rp.has_precomputed_row_lengths(),
                     'row_lengths' in expected_encodings)
    self.assertEqual(rp.has_precomputed_value_rowids(),
                     'value_rowids' in expected_encodings)
    self.assertEqual(rp.has_precomputed_nrows(), 'nrows' in expected_encodings)

  def testWithPrecomputedSplits(self):
    """Each with_precomputed_* method yields a partition with that encoding."""
    rp = RowPartition.from_row_splits([0, 2, 8])

    rp_with_row_splits = rp.with_precomputed_row_splits()
    self.assertTrue(rp_with_row_splits.has_precomputed_row_splits())

    self.assertFalse(rp.has_precomputed_row_lengths())
    rp_with_row_lengths = rp.with_precomputed_row_lengths()
    self.assertTrue(rp_with_row_lengths.has_precomputed_row_lengths())

    self.assertFalse(rp.has_precomputed_value_rowids())
    rp_with_value_rowids = rp.with_precomputed_value_rowids()
    self.assertTrue(rp_with_value_rowids.has_precomputed_value_rowids())

    self.assertFalse(rp.has_precomputed_nrows())
    rp_with_nrows = rp.with_precomputed_nrows()
    self.assertTrue(rp_with_nrows.has_precomputed_nrows())

  @parameterized.named_parameters([
      dict(
          testcase_name='FromRowSplitsAndRowSplits',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8]),
          expected_encodings=['row_splits']),
      dict(
          testcase_name='FromRowSplitsAndUniformRowLength',
          x=lambda: RowPartition.from_row_splits([0, 3, 6]),
          y=lambda: RowPartition.from_uniform_row_length(3, nvals=6),
          expected_encodings=['row_splits', 'uniform_row_length', 'nrows']),
      dict(
          testcase_name='FromRowSplitsAndRowLengths',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_lengths([3, 5]),
          expected_encodings=['row_splits', 'row_lengths']),
      dict(
          testcase_name='FromRowSplitsAndValueRowIds',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1]),
          expected_encodings=[
              'row_splits', 'row_lengths', 'value_rowids', 'nrows'
          ]),
      dict(
          testcase_name='FromRowSplitsAndRowSplitsPlusNRows',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8]).
          with_precomputed_nrows(),
          expected_encodings=['row_splits', 'nrows']),
  ])
  def testMergePrecomputedEncodings(self, x, y, expected_encodings):
    """Checks the encodings of `x.merge_precomputed_encodings(y)`.

    The merged partition must report exactly `expected_encodings` as
    precomputed (with and without validation), and every encoding it shares
    with an input must equal that input's value.

    Args:
      x: Zero-argument callable returning the first RowPartition.
      y: Zero-argument callable returning the second RowPartition.
      expected_encodings: Names of the encodings expected on the result.
    """
    x = x()
    y = y()
    for validate in (True, False):
      result = x.merge_precomputed_encodings(y, validate)
      self.assertEqual(result.has_precomputed_row_splits(),
                       'row_splits' in expected_encodings)
      self.assertEqual(result.has_precomputed_row_lengths(),
                       'row_lengths' in expected_encodings)
      self.assertEqual(result.has_precomputed_value_rowids(),
                       'value_rowids' in expected_encodings)
      self.assertEqual(result.has_precomputed_nrows(),
                       'nrows' in expected_encodings)
      self.assertEqual(result.uniform_row_length() is not None,
                       'uniform_row_length' in expected_encodings)
      for r in (x, y):
        if (r.has_precomputed_row_splits() and
            result.has_precomputed_row_splits()):
          self.assertAllEqual(r.row_splits(), result.row_splits())
        if (r.has_precomputed_row_lengths() and
            result.has_precomputed_row_lengths()):
          self.assertAllEqual(r.row_lengths(), result.row_lengths())
        if (r.has_precomputed_value_rowids() and
            result.has_precomputed_value_rowids()):
          self.assertAllEqual(r.value_rowids(), result.value_rowids())
        if r.has_precomputed_nrows() and result.has_precomputed_nrows():
          self.assertAllEqual(r.nrows(), result.nrows())
        if (r.uniform_row_length() is not None and
            result.uniform_row_length() is not None):
          self.assertAllEqual(r.uniform_row_length(),
                              result.uniform_row_length())

  def testMergePrecomputedEncodingsFastPaths(self):
    # Same object: x gets returned as-is.
    x = RowPartition.from_row_splits([0, 3, 8, 8])
    self.assertIs(x.merge_precomputed_encodings(x), x)

    # Same encoding tensor objects: x gets returned as-is.
    y = RowPartition.from_row_splits(x.row_splits(), validate=False)
    self.assertIs(x.merge_precomputed_encodings(y), x)

  def testMergePrecomputedEncodingsWithMatchingTensors(self):
    # The encoding tensors for `a` are a superset of the encoding tensors
    # for `b`, and where they overlap, they are the same tensor objects.
    a = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4])
    b = RowPartition.from_row_splits(a.row_splits(), validate=False)
    self.assertIs(a.merge_precomputed_encodings(b), a)
    self.assertIs(b.merge_precomputed_encodings(a), a)
    self.assertIsNot(a, b)

  @parameterized.named_parameters([
      dict(
          testcase_name='RowSplitMismatch',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]),
          message='incompatible row_splits'),
      dict(
          testcase_name='RowLengthMismatch',
          x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
          y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]),
          message='incompatible row_splits'),  # row_splits is checked first
      dict(
          testcase_name='ValueRowIdMismatch',
          x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]),
          y=lambda: RowPartition.from_value_rowids([0, 3, 4]),
          message='incompatible value_rowids'),
  ])
  def testMergePrecomputedEncodingStaticErrors(self, x, y, message):
    """Shape-incompatible merges raise ValueError in graph mode.

    Skipped (returns immediately) when executing eagerly.

    Args:
      x: Zero-argument callable returning the first RowPartition.
      y: Zero-argument callable returning the second RowPartition.
      message: Regex expected in the ValueError message.
    """
    if context.executing_eagerly():
      return
    # Errors that are caught by static shape checks.
    x = x()
    y = y()
    with self.assertRaisesRegex(ValueError, message):
      x.merge_precomputed_encodings(y).row_splits()
    with self.assertRaisesRegex(ValueError, message):
      y.merge_precomputed_encodings(x).row_splits()

  @parameterized.named_parameters([
      dict(
          testcase_name='NRowsMismatch',
          x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
          y=lambda: RowPartition.from_uniform_row_length(5, nvals=15),
          message='incompatible nrows'),
      dict(
          testcase_name='UniformRowLengthMismatch',
          x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
          y=lambda: RowPartition.from_uniform_row_length(2, nvals=8),
          message='incompatible uniform_row_length'),
      dict(
          testcase_name='RowSplitMismatch',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 5, 8]),
          message='incompatible row_splits'),
      dict(
          testcase_name='RowLengthMismatch',
          x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
          y=lambda: RowPartition.from_row_lengths([0, 0, 2]),
          message='incompatible row_splits'),  # row_splits is checked first
      dict(
          testcase_name='ValueRowIdMismatch',
          x=lambda: RowPartition.from_value_rowids([0, 3, 3]),
          y=lambda: RowPartition.from_value_rowids([0, 0, 3]),
          message='incompatible row_splits'),  # row_splits is checked first
  ])
  def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message):
    """Value-incompatible merges raise InvalidArgumentError when evaluated.

    Args:
      x: Zero-argument callable returning the first RowPartition.
      y: Zero-argument callable returning the second RowPartition.
      message: Regex expected in the InvalidArgumentError message.
    """
    # Errors that are caught by runtime value checks.
    x = x()
    y = y()
    with self.assertRaisesRegex(errors.InvalidArgumentError, message):
      self.evaluate(x.merge_precomputed_encodings(y).row_splits())
    with self.assertRaisesRegex(errors.InvalidArgumentError, message):
      self.evaluate(y.merge_precomputed_encodings(x).row_splits())
Example #4
0
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    #=============================================================================
    # RowPartition class docstring examples
    #=============================================================================

    def testClassDocStringExamples(self):
        # Re-runs the snippets from the RowPartition class docstring so the
        # documented examples stay valid.

        # Section "Component Tensors": a partition built directly from splits.
        rp = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
        self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
        del rp

        # Section "Alternative Row-Partitioning Schemes": every factory below
        # describes the same five rows (lengths 4, 0, 3, 1, 0), so they must
        # all agree on row_splits.
        expected_splits = [0, 4, 4, 7, 8, 8]
        equivalent_partitions = [
            RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]),
            RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0]),
            RowPartition.from_value_rowids(
                value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5),
            RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8],
                                         nvals=8),
            RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8]),
        ]
        for partition in equivalent_partitions:
            self.assertAllEqual(partition.row_splits(), expected_splits)
        del equivalent_partitions

        # Section "Multiple Ragged Dimensions": only construction is checked.
        inner_rp = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
        outer_rp = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
        del inner_rp, outer_rp

    #=============================================================================
    # RowPartition Constructor (private)
    #=============================================================================

    def testRowPartitionConstruction(self):
        # The private constructor accepts row_splits when given the factory key.
        splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
        factory_key = row_partition._row_partition_factory_key
        partition = RowPartition(row_splits=splits, internal=factory_key)
        self.assertAllEqual(partition.row_splits(), [0, 2, 2, 5, 6, 7])

    def testRowPartitionConstructionErrors(self):
        # Each case violates one precondition of the private constructor.
        splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
        key = row_partition._row_partition_factory_key

        # Calling the constructor without the factory key is rejected.
        with self.assertRaisesRegex(ValueError,
                                    'RowPartition constructor is private'):
            RowPartition(row_splits=splits)

        # row_splits must already be a Tensor; a Python list is rejected.
        with self.assertRaisesRegex(
                TypeError, 'Row-partitioning argument must be a Tensor'):
            RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=key)

        # row_splits must be rank 1.
        with self.assertRaisesRegex(ValueError,
                                    r'Shape \(6, 1\) must have rank 1'):
            RowPartition(row_splits=array_ops.expand_dims(splits, 1),
                         internal=key)

        # Cached encodings such as row_lengths must be Tensors (or None).
        with self.assertRaisesRegex(TypeError,
                                    'Cached value must be a Tensor or None.'):
            RowPartition(row_splits=splits,
                         row_lengths=[2, 3, 4],
                         internal=key)

        # All encodings must share one dtype (int64 splits vs. int32 nrows).
        with self.assertRaisesRegex(ValueError, 'Inconsistent dtype'):
            RowPartition(row_splits=constant_op.constant([0, 3], dtypes.int64),
                         nrows=constant_op.constant(1, dtypes.int32),
                         internal=key)

    #=============================================================================
    # RowPartition Factory Ops
    #=============================================================================

    def testFromValueRowIdsWithDerivedNRows(self):
        # nrows is not passed and must be inferred from value_rowids, which is
        # a constant here: the largest row id is 4, giving 5 rows.
        ids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
        # TODO(martinz): add nrows
        partition = RowPartition.from_value_rowids(ids, validate=False)
        self.assertEqual(partition.dtype, dtypes.int64)

        splits = partition.row_splits()
        ids_out = partition.value_rowids()
        nrows = partition.nrows()

        # The caller's value_rowids tensor is stored as-is (same object).
        self.assertIs(ids_out, ids)
        self.assertAllEqual(ids_out, ids)
        self.assertAllEqual(nrows, 5)
        self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7])

    def testFromValueRowIdsWithDerivedNRowsDynamic(self):
        # Wrapping value_rowids in a placeholder with shape=None removes its
        # static shape, so nrows is not known at graph-construction time.
        ids = array_ops.placeholder_with_default(
            constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64),
            shape=None)

        partition = RowPartition.from_value_rowids(ids, validate=False)

        ids_out = partition.value_rowids()
        nrows = partition.nrows()

        # The placeholder itself is kept, and nrows still evaluates to 5.
        self.assertIs(ids_out, ids)
        self.assertAllEqual(ids_out, ids)
        self.assertAllEqual(nrows, 5)

    def testFromValueRowIdsWithExplicitNRows(self):
        # An explicit nrows (7) larger than the largest row id + 1 (5) adds
        # empty trailing rows.
        ids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
        requested_nrows = constant_op.constant(7, dtypes.int64)

        partition = RowPartition.from_value_rowids(
            ids, requested_nrows, validate=False)

        ids_out = partition.value_rowids()
        nrows_out = partition.nrows()
        splits = partition.row_splits()

        # Both caller-supplied tensors are kept by identity.
        self.assertIs(ids_out, ids)
        self.assertIs(nrows_out, requested_nrows)
        # Rows 5 and 6 are empty, hence the repeated trailing 7s.
        self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7, 7, 7])

    def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
        # Passing nrows equal to the inferred count (largest row id 4 + 1 = 5)
        # yields the same row_splits as omitting it.
        ids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
        requested_nrows = constant_op.constant(5, dtypes.int64)

        partition = RowPartition.from_value_rowids(
            ids, requested_nrows, validate=False)

        ids_out = partition.value_rowids()
        nrows_out = partition.nrows()
        splits = partition.row_splits()

        # The caller's tensors are kept by identity and keep their values.
        self.assertIs(ids_out, ids)
        self.assertIs(nrows_out, requested_nrows)
        self.assertAllEqual(ids_out, ids)
        self.assertAllEqual(nrows_out, requested_nrows)
        self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7])

    def testFromValueRowIdsWithEmptyValues(self):
        # An empty Python list of row ids gives an int64 partition with zero
        # rows and an empty (length-0) value_rowids vector.
        partition = RowPartition.from_value_rowids([])
        nrows = partition.nrows()
        self.assertEqual(partition.dtype, dtypes.int64)
        self.assertEqual(partition.value_rowids().shape.as_list(), [0])
        self.assertAllEqual(nrows, 0)

    def testFromRowSplits(self):
        # The row_splits tensor is stored directly (checked by identity), and
        # six split points describe five rows.
        splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

        partition = RowPartition.from_row_splits(splits, validate=False)
        self.assertEqual(partition.dtype, dtypes.int64)

        splits_out = partition.row_splits()
        nrows = partition.nrows()

        self.assertIs(splits_out, splits)
        self.assertAllEqual(nrows, 5)

    def testFromRowSplitsWithDifferentSplitTypes(self):
        # The resulting row_splits dtype follows the input: a Python list gives
        # int64, while numpy arrays and Tensors keep their own integer dtype.
        values = [0, 2, 2, 5, 6, 7]
        inputs_and_dtypes = [
            (list(values), dtypes.int64),
            (np.array(values, np.int64), dtypes.int64),
            (np.array(values, np.int32), dtypes.int32),
            (constant_op.constant(values, dtypes.int64), dtypes.int64),
            (constant_op.constant(values, dtypes.int32), dtypes.int32),
        ]
        partitions = [(RowPartition.from_row_splits(splits), expected)
                      for splits, expected in inputs_and_dtypes]
        for partition, expected_dtype in partitions:
            self.assertEqual(partition.row_splits().dtype, expected_dtype)

    def testFromRowSplitsWithEmptySplits(self):
        # An empty row_splits list is rejected with a ValueError.
        expected_error = 'row_splits tensor may not be empty'
        with self.assertRaisesRegex(ValueError, expected_error):
            RowPartition.from_row_splits([])

    def testFromRowStarts(self):
        # row_starts lacks the final boundary, so nvals (7) is supplied to
        # close the last row; row_splits is row_starts with nvals appended.
        total_values = constant_op.constant(7)
        starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)

        partition = RowPartition.from_row_starts(starts, total_values,
                                                 validate=False)
        self.assertEqual(partition.dtype, dtypes.int64)

        starts_out = partition.row_starts()
        splits = partition.row_splits()
        nrows = partition.nrows()

        self.assertAllEqual(nrows, 5)
        self.assertAllEqual(starts_out, starts)
        self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7])

    def testFromRowLimits(self):
        # row_limits lacks the leading 0; row_splits is [0] + row_limits.
        limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

        partition = RowPartition.from_row_limits(limits, validate=False)
        self.assertEqual(partition.dtype, dtypes.int64)

        limits_out = partition.row_limits()
        splits = partition.row_splits()
        nrows = partition.nrows()

        self.assertAllEqual(nrows, 5)
        self.assertAllEqual(limits_out, limits)
        self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7])

    def testFromRowLengths(self):
        # Build a partition from per-row lengths. The lengths tensor must be
        # kept as-is, and the derived nrows and row_splits must be consistent
        # with it.
        row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)

        rp = RowPartition.from_row_lengths(row_lengths, validate=False)
        self.assertEqual(rp.dtype, dtypes.int64)

        rp_row_lengths = rp.row_lengths()
        rp_row_splits = rp.row_splits()
        rp_nrows = rp.nrows()

        self.assertIs(rp_row_lengths, row_lengths)  # row_lengths kept as-is
        self.assertAllEqual(rp_nrows, 5)
        self.assertAllEqual(rp_row_lengths, row_lengths)
        # row_splits is the running sum of [2, 0, 3, 1, 1] with a leading 0,
        # matching the from_row_starts / from_row_limits tests.
        self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

    def testFromUniformRowLength(self):
        # 16 values in rows of length 2 give 16 / 2 = 8 rows.
        partition = RowPartition.from_uniform_row_length(
            nvals=16, uniform_row_length=2)
        self.assertAllEqual(partition.uniform_row_length(), 2)
        self.assertAllEqual(partition.nrows(), 8)

    def testFromUniformRowLengthWithEmptyValues(self):
        # Ten zero-length rows hold no values. nrows is passed explicitly
        # because nvals / uniform_row_length would be 0 / 0 here.
        partition = RowPartition.from_uniform_row_length(
            nvals=0, uniform_row_length=0, nrows=10)
        self.assertEqual(self.evaluate(partition.nvals()), 0)
        self.assertEqual(self.evaluate(partition.nrows()), 10)

    def testFromUniformRowLengthWithPlaceholders1(self):
        # nvals is an int64 placeholder with unknown shape; the partition
        # still reports the right nvals once evaluated.
        dynamic_nvals = array_ops.placeholder_with_default(
            constant_op.constant(6, dtype=dtypes.int64), None)
        partition = RowPartition.from_uniform_row_length(
            nvals=dynamic_nvals, uniform_row_length=3)
        self.assertEqual(self.evaluate(partition.nvals()), 6)

    def testFromUniformRowLengthWithPlaceholders2(self):
        # Both nvals and uniform_row_length are placeholders with unknown
        # shape; nvals must still evaluate to the fed default.
        dynamic_nvals = array_ops.placeholder_with_default(6, None)
        dynamic_row_length = array_ops.placeholder_with_default(3, None)
        partition = RowPartition.from_uniform_row_length(
            nvals=dynamic_nvals, uniform_row_length=dynamic_row_length)
        self.assertEqual(self.evaluate(partition.nvals()), 6)

    def testFromValueRowIdsWithBadNRows(self):
        # Each case passes an nrows (or input shape) that from_value_rowids
        # rejects with a ValueError while the partition is being built.
        rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
        scalar_nrows = constant_op.constant(5, dtypes.int64)

        # Negative nrows (value_rowids wrapped in a shapeless placeholder).
        with self.assertRaisesRegex(ValueError,
                                    r'Expected nrows >= 0; got -2'):
            RowPartition.from_value_rowids(
                value_rowids=array_ops.placeholder_with_default(rowids, None),
                nrows=-2)

        # nrows must be at least largest row id + 1 (= 5); 2 is too small...
        with self.assertRaisesRegex(
                ValueError,
                r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
                r'value_rowids\[-1\]=4'):
            RowPartition.from_value_rowids(value_rowids=rowids, nrows=2)

        # ...and 4 is off by one.
        with self.assertRaisesRegex(
                ValueError,
                r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
                r'value_rowids\[-1\]=4'):
            RowPartition.from_value_rowids(value_rowids=rowids, nrows=4)

        # value_rowids must be rank 1.
        with self.assertRaisesRegex(ValueError,
                                    r'Shape \(7, 1\) must have rank 1'):
            RowPartition.from_value_rowids(
                value_rowids=array_ops.expand_dims(rowids, 1),
                nrows=scalar_nrows)

        # nrows must be a scalar.
        with self.assertRaisesRegex(ValueError,
                                    r'Shape \(1,\) must have rank 0'):
            RowPartition.from_value_rowids(
                value_rowids=rowids,
                nrows=array_ops.expand_dims(scalar_nrows, 0))

    #=============================================================================
    # RowPartition.__str__
    #=============================================================================
    def testRowPartitionStr(self):
        # repr() and str() produce the same text. Under eager execution the
        # split values are printed; in graph mode the symbolic Tensor is shown.
        rp = RowPartition.from_row_splits([0, 2, 5, 6, 6, 7], validate=False)
        eager_repr = 'tf.RowPartition(row_splits=[0 2 5 6 6 7])'
        graph_repr = ('tf.RowPartition(row_splits='
                      'Tensor("RowPartitionFromRowSplits/row_splits:0", '
                      'shape=(6,), dtype=int64))')
        expected_repr = (eager_repr if context.executing_eagerly()
                         else graph_repr)
        self.assertEqual(repr(rp), expected_repr)
        self.assertEqual(str(rp), expected_repr)

    def testRowPartitionStrUniformRowLength(self):
        # A uniform partition is rendered via nrows and uniform_row_length,
        # as values (eager) or as symbolic Tensors (graph mode).
        rp = RowPartition.from_uniform_row_length(5, nvals=10, nrows=2)
        eager_repr = 'tf.RowPartition(nrows=2, uniform_row_length=5)'
        graph_repr = (
            'tf.RowPartition(nrows='
            'Tensor("RowPartitionFromUniformRowLength/'
            'nrows:0", shape=(), dtype=int64), '
            'uniform_row_length=Tensor("RowPartitionFromUniformRowLength/'
            'uniform_row_length:0", shape=(), dtype=int64))')
        expected_repr = (eager_repr if context.executing_eagerly()
                         else graph_repr)
        self.assertEqual(repr(rp), expected_repr)
        self.assertEqual(str(rp), expected_repr)

    @parameterized.parameters([
        # from_value_rowids
        {
            'descr': 'bad rank for value_rowids',
            'factory': RowPartition.from_value_rowids,
            'value_rowids': [[1, 2], [3, 4]],
            'nrows': 10
        },
        {
            'descr': 'bad rank for nrows',
            'factory': RowPartition.from_value_rowids,
            'value_rowids': [1, 2, 3, 4],
            'nrows': [10]
        },
        {
            'descr': 'negative value_rowid',
            'factory': RowPartition.from_value_rowids,
            'value_rowids': [-5, 2, 3, 4],
            'nrows': 10
        },
        {
            'descr': 'non-monotonic-increasing value_rowid',
            'factory': RowPartition.from_value_rowids,
            'value_rowids': [4, 3, 2, 1],
            'nrows': 10
        },
        {
            'descr': 'value_rowid > nrows',
            'factory': RowPartition.from_value_rowids,
            'value_rowids': [1, 2, 3, 4],
            'nrows': 2
        },

        # from_row_splits
        {
            'descr': 'bad rank for row_splits',
            'factory': RowPartition.from_row_splits,
            'row_splits': [[1, 2], [3, 4]]
        },
        {
            'descr': 'row_splits[0] != 0',
            'factory': RowPartition.from_row_splits,
            'row_splits': [2, 3, 4]
        },
        {
            'descr': 'non-monotonic-increasing row_splits',
            'factory': RowPartition.from_row_splits,
            'row_splits': [0, 3, 2, 4]
        },

        # from_row_lengths
        {
            'descr': 'bad rank for row_lengths',
            'factory': RowPartition.from_row_lengths,
            'row_lengths': [[1, 2], [1, 0]]
        },
        {
            'descr': 'negative row_lengths',
            'factory': RowPartition.from_row_lengths,
            'row_lengths': [3, -1, 2]
        },

        # from_row_starts
        {
            'descr': 'bad rank for row_starts',
            'factory': RowPartition.from_row_starts,
            'nvals': 2,
            'row_starts': [[1, 2], [3, 4]]
        },
        {
            'descr': 'row_starts[0] != 0',
            'factory': RowPartition.from_row_starts,
            'nvals': 5,
            'row_starts': [2, 3, 4]
        },
        {
            'descr': 'non-monotonic-increasing row_starts',
            'factory': RowPartition.from_row_starts,
            'nvals': 4,
            'row_starts': [0, 3, 2, 4]
        },
        {
            # The last start (5) exceeds nvals (4).
            'descr': 'row_starts[-1] > nvals',
            'factory': RowPartition.from_row_starts,
            'nvals': 4,
            'row_starts': [0, 2, 3, 5]
        },

        # from_row_limits
        {
            'descr': 'bad rank for row_limits',
            'factory': RowPartition.from_row_limits,
            'row_limits': [[1, 2], [3, 4]]
        },
        {
            'descr': 'row_limits[0] < 0',
            'factory': RowPartition.from_row_limits,
            'row_limits': [-1, 3, 4]
        },
        {
            'descr': 'non-monotonic-increasing row_limits',
            'factory': RowPartition.from_row_limits,
            'row_limits': [0, 3, 2, 4]
        },

        # from_uniform_row_length
        {
            'descr': 'rowlen * nrows != nvals (1)',
            'factory': RowPartition.from_uniform_row_length,
            'nvals': 5,
            'uniform_row_length': 3
        },
        {
            'descr': 'rowlen * nrows != nvals (2)',
            'factory': RowPartition.from_uniform_row_length,
            'nvals': 5,
            'uniform_row_length': 6
        },
        {
            'descr': 'rowlen * nrows != nvals (3)',
            'factory': RowPartition.from_uniform_row_length,
            'nvals': 6,
            'uniform_row_length': 3,
            'nrows': 3
        },
        {
            'descr': 'rowlen must be a scalar',
            'factory': RowPartition.from_uniform_row_length,
            'nvals': 4,
            'uniform_row_length': [2]
        },
        {
            'descr': 'rowlen must be nonnegative',
            'factory': RowPartition.from_uniform_row_length,
            'nvals': 4,
            'uniform_row_length': -1
        },
    ])
    def testFactoryValidation(self, descr, factory, **kwargs):
        # Each parameter set passes invalid arguments to one factory; the
        # remaining keyword arguments are forwarded to it unchanged.
        del descr  # Only labels the test case.

        # When input tensors have shape information, some of these errors will
        # be detected statically (ValueError) rather than at run time.
        with self.assertRaises((errors.InvalidArgumentError, ValueError)):
            partition = factory(**kwargs)
            self.evaluate(partition.row_splits())

        # Remove shape information (by wrapping tensors in placeholders), and
        # check that we detect the errors when the graph is run.
        if not context.executing_eagerly():

            def wrap_arg(v):
                # Shapeless int64 placeholder defaulting to `v`.
                return array_ops.placeholder_with_default(
                    constant_op.constant(v, dtype=dtypes.int64),
                    tensor_shape.TensorShape(None))

            kwargs = {k: wrap_arg(v) for k, v in kwargs.items()}

            with self.assertRaises(errors.InvalidArgumentError):
                partition = factory(**kwargs)
                self.evaluate(partition.row_splits())

    @parameterized.named_parameters([
        ('FromRowSplits',
         lambda: RowPartition.from_row_splits([0, 2, 8]),
         ['row_splits']),
        ('FromRowLengths',
         lambda: RowPartition.from_row_lengths([3, 0, 8]),
         ['row_splits', 'row_lengths']),
        ('FromValueRowIds',
         lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]),
         ['row_splits', 'value_rowids', 'row_lengths', 'nrows']),
        ('FromRowStarts',
         lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10),
         ['row_splits']),
        ('FromRowLimits',
         lambda: RowPartition.from_row_limits([3, 7, 10]),
         ['row_splits']),
    ])
    def testPrecomputedSplits(self, rp_factory, expected_encodings):
        # A freshly built partition reports as precomputed exactly the
        # encodings listed in `expected_encodings` (e.g. from_value_rowids
        # also caches row_lengths and nrows).
        rp = rp_factory()
        for encoding in ('row_splits', 'row_lengths', 'value_rowids', 'nrows'):
            has_encoding = getattr(rp, '_has_precomputed_' + encoding)
            self.assertEqual(has_encoding(), encoding in expected_encodings)

    def testWithPrecomputedSplits(self):
        rp = RowPartition.from_row_splits([0, 2, 8])

        # rp was built from row_splits, so only the result of the matching
        # _with_precomputed_* call is checked.
        self.assertTrue(
            rp._with_precomputed_row_splits()._has_precomputed_row_splits())

        # The other encodings start out missing on rp and are present on the
        # partition returned by the matching _with_precomputed_* method.
        for encoding in ('row_lengths', 'value_rowids', 'nrows', 'nvals'):
            self.assertFalse(getattr(rp, '_has_precomputed_' + encoding)())
            updated = getattr(rp, '_with_precomputed_' + encoding)()
            self.assertTrue(
                getattr(updated, '_has_precomputed_' + encoding)())

    @parameterized.named_parameters([
        dict(testcase_name='FromRowSplitsAndRowSplits',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             y=lambda: RowPartition.from_row_splits([0, 3, 8]),
             expected_encodings=['row_splits']),
        dict(testcase_name='FromRowSplitsAndUniformRowLength',
             x=lambda: RowPartition.from_row_splits([0, 3, 6]),
             y=lambda: RowPartition.from_uniform_row_length(3, nvals=6),
             expected_encodings=['row_splits', 'uniform_row_length', 'nrows']),
        dict(testcase_name='FromRowSplitsAndRowLengths',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             y=lambda: RowPartition.from_row_lengths([3, 5]),
             expected_encodings=['row_splits', 'row_lengths']),
        dict(testcase_name='FromRowSplitsAndValueRowIds',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             y=lambda: RowPartition.from_value_rowids(
                 [0, 0, 0, 1, 1, 1, 1, 1]),
             expected_encodings=[
                 'row_splits', 'row_lengths', 'value_rowids', 'nrows'
             ]),
        dict(testcase_name='FromRowSplitsAndRowSplitsPlusNRows',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             y=lambda: RowPartition.from_row_splits(
                 [0, 3, 8])._with_precomputed_nrows(),
             expected_encodings=['row_splits', 'nrows']),
    ])
    def testMergePrecomputedEncodings(self, x, y, expected_encodings):
        # Merging two compatible partitions (with and without validation)
        # gives a partition carrying exactly `expected_encodings`. Every
        # encoding it shares with an input must have the same value.
        tensor_encodings = ('row_splits', 'row_lengths', 'value_rowids',
                            'nrows')
        lhs = x()
        rhs = y()
        for validate in (True, False):
            merged = lhs._merge_precomputed_encodings(rhs, validate)
            for name in tensor_encodings:
                self.assertEqual(
                    getattr(merged, '_has_precomputed_' + name)(),
                    name in expected_encodings)
            self.assertEqual(merged.uniform_row_length() is not None,
                             'uniform_row_length' in expected_encodings)

            for source in (lhs, rhs):
                for name in tensor_encodings:
                    if (getattr(source, '_has_precomputed_' + name)() and
                            getattr(merged, '_has_precomputed_' + name)()):
                        self.assertAllEqual(getattr(source, name)(),
                                            getattr(merged, name)())
                if (source.uniform_row_length() is not None and
                        merged.uniform_row_length() is not None):
                    self.assertAllEqual(source.uniform_row_length(),
                                        merged.uniform_row_length())

    def testMergePrecomputedEncodingsFastPaths(self):
        base = RowPartition.from_row_splits([0, 3, 8, 8])

        # Merging a partition with itself returns that very object.
        self.assertIs(base._merge_precomputed_encodings(base), base)

        # A partition wrapping the very same row_splits tensor contributes
        # nothing new, so `base` is returned as-is.
        twin = RowPartition.from_row_splits(base.row_splits(), validate=False)
        self.assertIs(base._merge_precomputed_encodings(twin), base)

    def testMergePrecomputedEncodingsWithMatchingTensors(self):
        # `full` caches more encodings than `subset`, and where they overlap
        # they share the same tensor objects. Merging in either direction
        # therefore returns `full` itself.
        full = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4])
        subset = RowPartition.from_row_splits(full.row_splits(),
                                              validate=False)
        self.assertIs(full._merge_precomputed_encodings(subset), full)
        self.assertIs(subset._merge_precomputed_encodings(full), full)
        self.assertIsNot(full, subset)

    @parameterized.named_parameters([
        dict(testcase_name='RowSplitMismatch',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]),
             message='incompatible row_splits'),
        # row_splits is compared first, so its message is reported here.
        dict(testcase_name='RowLengthMismatch',
             x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
             y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]),
             message='incompatible row_splits'),
        dict(testcase_name='ValueRowIdMismatch',
             x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]),
             y=lambda: RowPartition.from_value_rowids([0, 3, 4]),
             message='incompatible value_rowids'),
    ])
    def testMergePrecomputedEncodingStaticErrors(self, x, y, message):
        # These mismatches are caught by static shape checks. The test only
        # runs in graph mode and returns immediately under eager execution.
        if context.executing_eagerly():
            return
        lhs, rhs = x(), y()
        for first, second in ((lhs, rhs), (rhs, lhs)):
            with self.assertRaisesRegex(ValueError, message):
                first._merge_precomputed_encodings(second).row_splits()

    @parameterized.named_parameters([
        dict(testcase_name='NRowsMismatchAlt',
             x=lambda: RowPartition.from_uniform_row_length(
                 5, nrows=4, nvals=20),
             y=lambda: RowPartition.from_uniform_row_length(
                 5, nrows=3, nvals=15),
             message='incompatible nrows'),
        dict(testcase_name='UniformRowLengthMismatch',
             x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
             y=lambda: RowPartition.from_uniform_row_length(2, nvals=8),
             message='incompatible (nvals|uniform_row_length)'),
        dict(testcase_name='RowSplitMismatch',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             y=lambda: RowPartition.from_row_splits([0, 5, 8]),
             message='incompatible row_splits'),
        dict(testcase_name='RowLengthMismatch',
             x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
             y=lambda: RowPartition.from_row_lengths([0, 0, 2]),
             message='incompatible (row_splits|nvals)'),
        # row_splits is compared first, so its message is reported here.
        dict(testcase_name='ValueRowIdMismatch',
             x=lambda: RowPartition.from_value_rowids([0, 3, 3]),
             y=lambda: RowPartition.from_value_rowids([0, 0, 3]),
             message='incompatible row_splits'),
    ])
    def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message):
        # These mismatches are only caught by runtime value checks, so the
        # merged row_splits has to be evaluated. Merging must fail in both
        # directions.
        lhs, rhs = x(), y()
        for first, second in ((lhs, rhs), (rhs, lhs)):
            with self.assertRaisesRegex(errors.InvalidArgumentError, message):
                self.evaluate(
                    first._merge_precomputed_encodings(second).row_splits())

    @parameterized.named_parameters([
        # The same nrows mismatch is reported with a different message
        # depending on execution mode: `message` is expected in graph mode and
        # `emessage` under eager execution.
        dict(testcase_name='NRowsMismatch',
             x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
             y=lambda: RowPartition.from_uniform_row_length(5, nvals=15),
             message='incompatible nvals',
             emessage='incompatible nrows'),
    ])
    def testMergePrecomputedEncodingStaticErrors2(self, x, y, message,
                                                  emessage):
        # The error type is InvalidArgumentError in both modes; only the
        # message differs, so pick the pattern for the current mode.
        x = x()
        y = y()

        expected_message = emessage if context.executing_eagerly() else message
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    expected_message):
            self.evaluate(x._merge_precomputed_encodings(y).row_splits())
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    expected_message):
            self.evaluate(y._merge_precomputed_encodings(x).row_splits())

    @parameterized.named_parameters([
        dict(testcase_name='from_uniform_row_length',
             x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
             expected=True),
        dict(testcase_name='from_row_splits',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             expected=False),
        dict(testcase_name='from_row_lengths',
             x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
             expected=False),
        # Equal row lengths alone do not make a partition uniform.
        dict(testcase_name='from_row_lengths_uniform',
             x=lambda: RowPartition.from_row_lengths([3, 3, 3]),
             expected=False),
    ])
    def testIsUniform(self, x, expected):
        # is_uniform() is True only for the from_uniform_row_length case here.
        self.assertEqual(expected, x().is_uniform())

    @parameterized.named_parameters([
        dict(testcase_name='doc_example',
             x=lambda: RowPartition.from_row_lengths([3, 2, 0, 2]),
             expected=[0, 1, 2, 0, 1, 0, 1]),
        dict(testcase_name='from_uniform_row_length',
             x=lambda: RowPartition.from_uniform_row_length(4, nvals=12),
             expected=[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
        dict(testcase_name='from_row_splits',
             x=lambda: RowPartition.from_row_splits([0, 3, 8]),
             expected=[0, 1, 2, 0, 1, 2, 3, 4]),
    ])
    def testOffsetsInRows(self, x, expected):
        # offsets_in_rows() gives each value's position within its own row,
        # restarting at 0 for every row (empty rows contribute nothing).
        self.assertAllEqual(expected, x().offsets_in_rows())

    def testFromUniformRowLengthBugConvertToTensor(self):
        # Regression test: this used to fail because nrows is int32 while
        # nvals and the row length are int64. With dtype=int64 requested, the
        # resulting nrows must be int64.
        # NOTE: the dtype semantics when no dtype is given are still open. In
        # convert_to_tensor, `dtype` fixes the output type, while `dtype_hint`
        # (the newer name for preferred_dtype) is only a suggestion.
        nrows_i32 = constant_op.constant(3, dtype=dtypes.int32)
        nvals_i64 = constant_op.constant(12, dtype=dtypes.int64)
        row_length_i64 = constant_op.constant(4, dtype=dtypes.int64)
        partition = RowPartition.from_uniform_row_length(
            row_length_i64,
            nvals=nvals_i64,
            nrows=nrows_i32,
            dtype=dtypes.int64)
        self.assertEqual(partition.nrows().dtype, dtypes.int64)

    def testFromUniformRowLengthNvalDynamic(self):
        # Open question: if nrows and uniform_row_length are known statically
        # but nvals is given only as a runtime tensor, should RowPartition still
        # determine nvals statically (here 3 * 12 == 36)?
        #
        # The body below is disabled. Skip explicitly instead of `pass` so the
        # test is reported as skipped rather than as a vacuous pass.
        # TODO(martinz): Remove the skip and uncomment after nvals is fixed.
        self.skipTest('Disabled until nvals is fixed.')
        # @def_function.function(
        #     input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
        # def foo(nvals):
        #   rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3)
        #   nval_output = tensor_util.constant_value(rp.nvals())
        #   self.assertEqual(nval_output, 36)
        # foo(constant_op.constant(36, dtype=dtypes.int32))

    def testFromUniformRowLengthNvalDynamicNoValidate(self):
        # Same open question as the validated variant, with validate=False: if
        # nrows and uniform_row_length are known statically but nvals is only a
        # runtime tensor, should nvals still be determined statically (36)?
        #
        # The body below is disabled. Skip explicitly instead of `pass` so the
        # test is reported as skipped rather than as a vacuous pass.
        # TODO(martinz): Remove the skip and uncomment after nvals is fixed.
        self.skipTest('Disabled until nvals is fixed.')
        # @def_function.function(
        #     input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
        # def foo(nvals):
        #   rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3,
        #                                             validate=False)
        #   nval_output = tensor_util.constant_value(rp.nvals())
        #   self.assertEqual(nval_output, 36)
        # foo(constant_op.constant(36, dtype=dtypes.int32))

    def testFromUniformRowLengthNvalDynamicWrong(self):
        # nrows (3) and uniform_row_length (12) are static, but nvals arrives
        # only at runtime and is inconsistent with them (7 != 3 * 12).
        # Evaluating nvals() is expected to raise InvalidArgumentError.
        # NOTE(review): an earlier comment said the consistency check is only
        # enforced for row_splits, yet this test expects nvals() to fail as
        # well; confirm which behavior is intended.
        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
        def compute_nvals(runtime_nvals):
            partition = RowPartition.from_uniform_row_length(
                12, nvals=runtime_nvals, nrows=3)
            return partition.nvals()

        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(
                compute_nvals(constant_op.constant(7, dtype=dtypes.int32)))

    def testFromUniformRowLengthNvalDynamicWrongRowSplits(self):
        # nrows (3) and uniform_row_length (12) are static, but nvals is only
        # known at runtime and is WRONG (7 != 3 * 12). from_uniform_row_length
        # adds a consistency check, and that check must fire when row_splits()
        # is evaluated, raising InvalidArgumentError.
        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
        def foo(nvals):
            rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3)
            return rp.row_splits()

        with self.assertRaises(errors.InvalidArgumentError):
            rs = foo(constant_op.constant(7, dtype=dtypes.int32))
            self.evaluate(rs)

    def testFromUniformRowPartitionNrows(self):
        # Given only nrows and the row length, nvals is derived statically as
        # 4 rows * 3 values = 12.
        partition = RowPartition.from_uniform_row_length(3, nrows=4)
        self.assertAllEqual(4, partition.nrows())
        self.assertAllEqual(3, partition.uniform_row_length())
        self.assertAllEqual(12, partition.static_nvals)

    def testFromUniformRowPartitionNvalsStatic(self):
        # Given only nvals and the row length, the static row count is derived
        # as 12 values / 3 per row = 4 rows.
        partition = RowPartition.from_uniform_row_length(3, nvals=12)
        self.assertAllEqual(4, partition.static_nrows)
        self.assertAllEqual(3, partition.static_uniform_row_length)
        self.assertAllEqual(12, partition.static_nvals)

    def testFromUniformRowPartitionNvalsStaticNoValidate(self):
        # With every size supplied and validation disabled, the static
        # accessors should simply report the consistent inputs.
        partition = RowPartition.from_uniform_row_length(
            3, nrows=4, nvals=12, validate=False)
        self.assertAllEqual(4, partition.static_nrows)
        self.assertAllEqual(3, partition.static_uniform_row_length)
        self.assertAllEqual(12, partition.static_nvals)

    def testFromUniformRowPartitionNvalsIs(self):
        # Intended check: a tensor passed as nvals should be returned as-is
        # (same object) by nvals().
        #
        # The body below is disabled. Skip explicitly instead of `pass` so the
        # test is reported as skipped rather than as a vacuous pass.
        # TODO(martinz): Remove the skip and uncomment after nvals is fixed.
        self.skipTest('Disabled until nvals is fixed.')
        # nvals = constant_op.constant(12)
        # rp = RowPartition.from_uniform_row_length(3, nvals=nvals)
        # self.assertIs(rp.nvals(), nvals)

    def testFromUniformRowPartitionRowStartsStatic(self):
        # An explicit static nvals passed to from_row_starts should be exposed
        # unchanged through static_nvals.
        starts = [0, 3, 6]
        self.assertAllEqual(
            12, RowPartition.from_row_starts(starts, nvals=12).static_nvals)

    def testStaticNrows(self):
        # Four row splits describe three rows; the static row count must be a
        # plain Python int rather than a tensor.
        splits = [0, 3, 4, 5]
        nrows = RowPartition.from_row_splits(splits).static_nrows
        self.assertIsInstance(nrows, int)
        self.assertAllEqual(3, nrows)

    def testStaticNrowsUnknown(self):
        # The traced function's input signature has an unknown shape, so at
        # trace time the row count cannot be known and static_nrows is None.
        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
        def check_unknown_nrows(row_splits):
            partition = RowPartition.from_row_splits(row_splits)
            self.assertIsNone(partition.static_nrows)

        check_unknown_nrows(
            array_ops.constant([0, 3, 4, 5], dtype=dtypes.int32))