def testFromRowLimits(self):
  """`from_row_limits` reproduces its limits and derives nrows/row_splits."""
  limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
  partition = RowPartition.from_row_limits(limits, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  # Query every encoding up front (same order as before), then check each.
  actual_limits, actual_splits, actual_nrows = (
      partition.row_limits(), partition.row_splits(), partition.nrows())

  self.assertAllEqual(actual_nrows, 5)
  self.assertAllEqual(actual_limits, limits)
  # row_splits is the row limits with a leading 0 prepended.
  self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7])
def testClassDocStringExamples(self):
  """Checks the examples used in the `RowPartition` class docstring."""
  # From section: "Component Tensors"
  rp = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
  del rp

  # From section: "Alternative Row-Partitioning Schemes"
  # Each factory below describes the same partition, so every one of them
  # must yield identical row_splits.
  rt1 = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  rt2 = RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0])
  rt3 = RowPartition.from_value_rowids(
      value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
  rt4 = RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8)
  rt5 = RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8])
  for rp in (rt1, rt2, rt3, rt4, rt5):
    self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
  del rt1, rt2, rt3, rt4, rt5

  # From section: "Multiple Ragged Dimensions"
  inner_rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  outer_rt = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
  self.assertAllEqual(inner_rt.row_splits(), [0, 4, 4, 7, 8, 8])
  self.assertAllEqual(outer_rt.row_splits(), [0, 3, 3, 5])
  # For the partitions to nest, the outer partition must partition exactly
  # the rows of the inner one: outer nvals (5) == inner nrows (5).
  self.assertAllEqual(outer_rt.nvals(), inner_rt.nrows())
  del inner_rt, outer_rt
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests `RowPartition` construction, factories, repr and encoding merges."""

  #=============================================================================
  # RowPartition class docstring examples
  #=============================================================================

  def testClassDocStringExamples(self):
    # From section: "Component Tensors"
    rp = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
    del rp

    # From section: "Alternative Row-Partitioning Schemes"
    # All five factories describe the same partition.
    rt1 = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    rt2 = RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0])
    rt3 = RowPartition.from_value_rowids(
        value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
    rt4 = RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8)
    rt5 = RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8])
    for rp in (rt1, rt2, rt3, rt4, rt5):
      self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8])
    del rt1, rt2, rt3, rt4, rt5

    # From section: "Multiple Ragged Dimensions"
    inner_rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
    outer_rt = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
    self.assertAllEqual(inner_rt.row_splits(), [0, 4, 4, 7, 8, 8])
    self.assertAllEqual(outer_rt.row_splits(), [0, 3, 3, 5])
    # Nesting requires outer nvals (5) == inner nrows (5).
    self.assertAllEqual(outer_rt.nvals(), inner_rt.nrows())
    del inner_rt, outer_rt

  #=============================================================================
  # RowPartition Constructor (private)
  #=============================================================================

  def testRaggedTensorConstruction(self):
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    rp = RowPartition(
        row_splits=row_splits,
        internal=row_partition._row_partition_factory_key)
    self.assertAllEqual(rp.row_splits(), [0, 2, 2, 5, 6, 7])

  def testRaggedTensorConstructionErrors(self):
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    # Calling the constructor without the private factory key is rejected.
    # NOTE(review): this regex must match the message raised by
    # RowPartition.__init__; confirm it still says "RaggedTensor" rather
    # than "RowPartition".
    with self.assertRaisesRegex(ValueError,
                                'RaggedTensor constructor is private'):
      RowPartition(row_splits=row_splits)

    # The constructor does not convert Python lists to tensors.
    with self.assertRaisesRegex(TypeError,
                                'Row-partitioning argument must be a Tensor'):
      RowPartition(
          row_splits=[0, 2, 2, 5, 6, 7],
          internal=row_partition._row_partition_factory_key)

    with self.assertRaisesRegex(ValueError, r'Shape \(6, 1\) must have rank 1'):
      RowPartition(
          row_splits=array_ops.expand_dims(row_splits, 1),
          internal=row_partition._row_partition_factory_key)

    with self.assertRaisesRegex(TypeError,
                                'Cached value must be a Tensor or None.'):
      RowPartition(
          row_splits=row_splits,
          row_lengths=[2, 3, 4],
          internal=row_partition._row_partition_factory_key)

  #=============================================================================
  # RowPartition Factory Ops
  #=============================================================================

  def testFromValueRowIdsWithDerivedNRows(self):
    # nrows is known at graph creation time.
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    # TODO(martinz): add nrows
    rp = RowPartition.from_value_rowids(value_rowids, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_splits = rp.row_splits()
    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertAllEqual(rp_value_rowids, value_rowids)
    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromValueRowIdsWithDerivedNRowsDynamic(self):
    # nrows is not known at graph creation time.
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None)

    rp = RowPartition.from_value_rowids(value_rowids, validate=False)

    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertAllEqual(rp_value_rowids, value_rowids)
    self.assertAllEqual(rp_nrows, 5)

  def testFromValueRowIdsWithExplicitNRows(self):
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(7, dtypes.int64)

    rp = RowPartition.from_value_rowids(value_rowids, nrows, validate=False)

    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()
    rp_row_splits = rp.row_splits()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertIs(rp_nrows, nrows)  # nrows
    # nrows=7 exceeds value_rowids[-1] + 1, so two trailing empty rows appear.
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7, 7, 7])

  def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(5, dtypes.int64)

    rp = RowPartition.from_value_rowids(value_rowids, nrows, validate=False)

    rp_value_rowids = rp.value_rowids()
    rp_nrows = rp.nrows()
    rp_row_splits = rp.row_splits()

    self.assertIs(rp_value_rowids, value_rowids)  # value_rowids
    self.assertIs(rp_nrows, nrows)  # nrows
    self.assertAllEqual(rp_value_rowids, value_rowids)
    self.assertAllEqual(rp_nrows, nrows)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromValueRowIdsWithEmptyValues(self):
    rp = RowPartition.from_value_rowids([])
    rp_nrows = rp.nrows()
    self.assertEqual(rp.dtype, dtypes.int64)
    self.assertEqual(rp.value_rowids().shape.as_list(), [0])
    self.assertAllEqual(rp_nrows, 0)

  def testFromRowSplits(self):
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    rp = RowPartition.from_row_splits(row_splits, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_splits = rp.row_splits()
    rp_nrows = rp.nrows()

    self.assertIs(rp_row_splits, row_splits)
    self.assertAllEqual(rp_nrows, 5)

  def testFromRowSplitsWithDifferentSplitTypes(self):
    # Python lists default to int64; numpy/Tensor inputs keep their dtype.
    splits1 = [0, 2, 2, 5, 6, 7]
    splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
    splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
    splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
    rt1 = RowPartition.from_row_splits(splits1)
    rt2 = RowPartition.from_row_splits(splits2)
    rt3 = RowPartition.from_row_splits(splits3)
    rt4 = RowPartition.from_row_splits(splits4)
    rt5 = RowPartition.from_row_splits(splits5)
    self.assertEqual(rt1.row_splits().dtype, dtypes.int64)
    self.assertEqual(rt2.row_splits().dtype, dtypes.int64)
    self.assertEqual(rt3.row_splits().dtype, dtypes.int32)
    self.assertEqual(rt4.row_splits().dtype, dtypes.int64)
    self.assertEqual(rt5.row_splits().dtype, dtypes.int32)

  def testFromRowSplitsWithEmptySplits(self):
    err_msg = 'row_splits tensor may not be empty'
    with self.assertRaisesRegex(ValueError, err_msg):
      RowPartition.from_row_splits([])

  def testFromRowStarts(self):
    nvals = constant_op.constant(7)
    row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)

    rp = RowPartition.from_row_starts(row_starts, nvals, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_starts = rp.row_starts()
    rp_row_splits = rp.row_splits()
    rp_nrows = rp.nrows()

    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_starts, row_starts)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromRowLimits(self):
    row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

    rp = RowPartition.from_row_limits(row_limits, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_limits = rp.row_limits()
    rp_row_splits = rp.row_splits()
    rp_nrows = rp.nrows()

    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_limits, row_limits)
    self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7])

  def testFromRowLengths(self):
    row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)

    rp = RowPartition.from_row_lengths(row_lengths, validate=False)
    self.assertEqual(rp.dtype, dtypes.int64)

    rp_row_lengths = rp.row_lengths()
    rp_nrows = rp.nrows()

    self.assertIs(rp_row_lengths, row_lengths)  # row_lengths
    self.assertAllEqual(rp_nrows, 5)
    self.assertAllEqual(rp_row_lengths, row_lengths)

  def testFromUniformRowLength(self):
    nvals = 16
    a1 = RowPartition.from_uniform_row_length(
        nvals=nvals, uniform_row_length=2)
    self.assertAllEqual(a1.uniform_row_length(), 2)
    self.assertAllEqual(a1.nrows(), 8)

  def testFromUniformRowLengthWithEmptyValues(self):
    # With uniform_row_length=0, nrows cannot be inferred and is given.
    a = RowPartition.from_uniform_row_length(
        nvals=0, uniform_row_length=0, nrows=10)
    self.assertEqual(self.evaluate(a.nvals()), 0)
    self.assertEqual(self.evaluate(a.nrows()), 10)

  def testFromUniformRowLengthWithPlaceholders1(self):
    nvals = array_ops.placeholder_with_default(
        constant_op.constant(6, dtype=dtypes.int64), None)
    rt1 = RowPartition.from_uniform_row_length(
        nvals=nvals, uniform_row_length=3)
    const_nvals1 = self.evaluate(rt1.nvals())
    self.assertEqual(const_nvals1, 6)

  def testFromUniformRowLengthWithPlaceholders2(self):
    nvals = array_ops.placeholder_with_default(6, None)
    ph_rowlen = array_ops.placeholder_with_default(3, None)
    rt2 = RowPartition.from_uniform_row_length(
        nvals=nvals, uniform_row_length=ph_rowlen)
    const_nvals2 = self.evaluate(rt2.nvals())
    self.assertEqual(const_nvals2, 6)

  def testFromValueRowIdsWithBadNRows(self):
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(5, dtypes.int64)

    with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
      RowPartition.from_value_rowids(
          value_rowids=array_ops.placeholder_with_default(value_rowids, None),
          nrows=-2)

    with self.assertRaisesRegex(
        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
        r'value_rowids\[-1\]=4'):
      RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=2)

    # nrows=4 is still one short of value_rowids[-1] + 1 == 5.
    with self.assertRaisesRegex(
        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
        r'value_rowids\[-1\]=4'):
      RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=4)

    with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
      RowPartition.from_value_rowids(
          value_rowids=array_ops.expand_dims(value_rowids, 1), nrows=nrows)

    with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
      RowPartition.from_value_rowids(
          value_rowids=value_rowids, nrows=array_ops.expand_dims(nrows, 0))

  #=============================================================================
  # RowPartition.__str__
  #=============================================================================

  def testRowPartitionStr(self):
    row_splits = [0, 2, 5, 6, 6, 7]
    rp = RowPartition.from_row_splits(row_splits, validate=False)
    if context.executing_eagerly():
      expected_repr = ('tf.RowPartition(row_splits=tf.Tensor([0 2 5 6 6 7], '
                       'shape=(6,), dtype=int64))')
    else:
      expected_repr = ('tf.RowPartition(row_splits='
                       'Tensor("RowPartitionFromRowSplits/row_splits:0", '
                       'shape=(6,), dtype=int64))')
    self.assertEqual(repr(rp), expected_repr)
    self.assertEqual(str(rp), expected_repr)

  @parameterized.parameters([
      # from_value_rowids
      {
          'descr': 'bad rank for value_rowids',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [[1, 2], [3, 4]],
          'nrows': 10
      },
      {
          'descr': 'bad rank for nrows',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [1, 2, 3, 4],
          'nrows': [10]
      },
      {
          'descr': 'negative value_rowid',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [-5, 2, 3, 4],
          'nrows': 10
      },
      {
          'descr': 'non-monotonic-increasing value_rowid',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [4, 3, 2, 1],
          'nrows': 10
      },
      {
          'descr': 'value_rowid > nrows',
          'factory': RowPartition.from_value_rowids,
          'value_rowids': [1, 2, 3, 4],
          'nrows': 2
      },

      # from_row_splits
      {
          'descr': 'bad rank for row_splits',
          'factory': RowPartition.from_row_splits,
          'row_splits': [[1, 2], [3, 4]]
      },
      {
          'descr': 'row_splits[0] != 0',
          'factory': RowPartition.from_row_splits,
          'row_splits': [2, 3, 4]
      },
      {
          'descr': 'non-monotonic-increasing row_splits',
          'factory': RowPartition.from_row_splits,
          'row_splits': [0, 3, 2, 4]
      },

      # from_row_lengths
      {
          'descr': 'bad rank for row_lengths',
          'factory': RowPartition.from_row_lengths,
          'row_lengths': [[1, 2], [1, 0]]
      },
      {
          'descr': 'negative row_lengths',
          'factory': RowPartition.from_row_lengths,
          'row_lengths': [3, -1, 2]
      },

      # from_row_starts
      {
          'descr': 'bad rank for row_starts',
          'factory': RowPartition.from_row_starts,
          'nvals': 2,
          'row_starts': [[1, 2], [3, 4]]
      },
      {
          'descr': 'row_starts[0] != 0',
          'factory': RowPartition.from_row_starts,
          'nvals': 5,
          'row_starts': [2, 3, 4]
      },
      {
          'descr': 'non-monotonic-increasing row_starts',
          'factory': RowPartition.from_row_starts,
          'nvals': 4,
          'row_starts': [0, 3, 2, 4]
      },
      {
          # The last start (5) exceeds nvals (4); row_starts[0] is 0.
          'descr': 'row_starts[-1] > nvals',
          'factory': RowPartition.from_row_starts,
          'nvals': 4,
          'row_starts': [0, 2, 3, 5]
      },

      # from_row_limits
      {
          'descr': 'bad rank for row_limits',
          'factory': RowPartition.from_row_limits,
          'row_limits': [[1, 2], [3, 4]]
      },
      {
          'descr': 'row_limits[0] < 0',
          'factory': RowPartition.from_row_limits,
          'row_limits': [-1, 3, 4]
      },
      {
          'descr': 'non-monotonic-increasing row_limits',
          'factory': RowPartition.from_row_limits,
          'row_limits': [0, 3, 2, 4]
      },

      # from_uniform_row_length
      {
          'descr': 'rowlen * nrows != nvals (1)',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 5,
          'uniform_row_length': 3
      },
      {
          'descr': 'rowlen * nrows != nvals (2)',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 5,
          'uniform_row_length': 6
      },
      {
          'descr': 'rowlen * nrows != nvals (3)',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 6,
          'uniform_row_length': 3,
          'nrows': 3
      },
      {
          'descr': 'rowlen must be a scalar',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 4,
          'uniform_row_length': [2]
      },
      {
          'descr': 'rowlen must be nonnegative',
          'factory': RowPartition.from_uniform_row_length,
          'nvals': 4,
          'uniform_row_length': -1
      },
  ])
  def testFactoryValidation(self, descr, factory, **kwargs):
    """Each factory rejects the invalid `kwargs` described by `descr`."""
    del descr  # Only labels the parameterized case.

    # When input tensors have shape information, some of these errors will be
    # detected statically.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      partition = factory(**kwargs)
      self.evaluate(partition.row_splits())

    # Remove shape information (by wrapping tensors in placeholders), and check
    # that we detect the errors when the graph is run.
    if not context.executing_eagerly():

      def wrap_arg(v):
        return array_ops.placeholder_with_default(
            constant_op.constant(v, dtype=dtypes.int64),
            tensor_shape.TensorShape(None))

      kwargs = {k: wrap_arg(v) for k, v in kwargs.items()}

      with self.assertRaises(errors.InvalidArgumentError):
        partition = factory(**kwargs)
        self.evaluate(partition.row_splits())

  @parameterized.named_parameters([
      ('FromRowSplits', lambda: RowPartition.from_row_splits([0, 2, 8]),
       ['row_splits']),
      ('FromRowLengths', lambda: RowPartition.from_row_lengths([3, 0, 8]),
       ['row_splits', 'row_lengths']),
      ('FromValueRowIds',
       lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]),
       ['row_splits', 'value_rowids', 'row_lengths', 'nrows']),
      ('FromRowStarts',
       lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10),
       ['row_splits']),
      ('FromRowLimits', lambda: RowPartition.from_row_limits([3, 7, 10]),
       ['row_splits']),
  ])
  def testPrecomputedSplits(self, rp_factory, expected_encodings):
    """Each factory precomputes exactly the encodings in `expected_encodings`."""
    rp = rp_factory()
    self.assertEqual(rp.has_precomputed_row_splits(),
                     'row_splits' in expected_encodings)
    self.assertEqual(rp.has_precomputed_row_lengths(),
                     'row_lengths' in expected_encodings)
    self.assertEqual(rp.has_precomputed_value_rowids(),
                     'value_rowids' in expected_encodings)
    self.assertEqual(rp.has_precomputed_nrows(),
                     'nrows' in expected_encodings)

  def testWithPrecomputedSplits(self):
    rp = RowPartition.from_row_splits([0, 2, 8])

    rp_with_row_splits = rp.with_precomputed_row_splits()
    self.assertTrue(rp_with_row_splits.has_precomputed_row_splits())

    # Each with_precomputed_* call returns a partition carrying the encoding;
    # the original `rp` is checked to lack it beforehand.
    self.assertFalse(rp.has_precomputed_row_lengths())
    rp_with_row_lengths = rp.with_precomputed_row_lengths()
    self.assertTrue(rp_with_row_lengths.has_precomputed_row_lengths())

    self.assertFalse(rp.has_precomputed_value_rowids())
    rp_with_value_rowids = rp.with_precomputed_value_rowids()
    self.assertTrue(rp_with_value_rowids.has_precomputed_value_rowids())

    self.assertFalse(rp.has_precomputed_nrows())
    rp_with_nrows = rp.with_precomputed_nrows()
    self.assertTrue(rp_with_nrows.has_precomputed_nrows())

  @parameterized.named_parameters([
      dict(
          testcase_name='FromRowSplitsAndRowSplits',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8]),
          expected_encodings=['row_splits']),
      dict(
          testcase_name='FromRowSplitsAndUniformRowLength',
          x=lambda: RowPartition.from_row_splits([0, 3, 6]),
          y=lambda: RowPartition.from_uniform_row_length(3, nvals=6),
          expected_encodings=['row_splits', 'uniform_row_length', 'nrows']),
      dict(
          testcase_name='FromRowSplitsAndRowLengths',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_lengths([3, 5]),
          expected_encodings=['row_splits', 'row_lengths']),
      dict(
          testcase_name='FromRowSplitsAndValueRowIds',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1]),
          expected_encodings=[
              'row_splits', 'row_lengths', 'value_rowids', 'nrows'
          ]),
      dict(
          testcase_name='FromRowSplitsAndRowSplitsPlusNRows',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8]).
          with_precomputed_nrows(),
          expected_encodings=['row_splits', 'nrows']),
  ])
  def testMergePrecomputedEncodings(self, x, y, expected_encodings):
    """Merged result has the union of encodings, and values agree with inputs."""
    x = x()
    y = y()
    for validate in (True, False):
      result = x.merge_precomputed_encodings(y, validate)
      self.assertEqual(result.has_precomputed_row_splits(),
                       'row_splits' in expected_encodings)
      self.assertEqual(result.has_precomputed_row_lengths(),
                       'row_lengths' in expected_encodings)
      self.assertEqual(result.has_precomputed_value_rowids(),
                       'value_rowids' in expected_encodings)
      self.assertEqual(result.has_precomputed_nrows(),
                       'nrows' in expected_encodings)
      self.assertEqual(result.uniform_row_length() is not None,
                       'uniform_row_length' in expected_encodings)
      # Every encoding present in both an input and the result must match.
      for r in (x, y):
        if (r.has_precomputed_row_splits() and
            result.has_precomputed_row_splits()):
          self.assertAllEqual(r.row_splits(), result.row_splits())
        if (r.has_precomputed_row_lengths() and
            result.has_precomputed_row_lengths()):
          self.assertAllEqual(r.row_lengths(), result.row_lengths())
        if (r.has_precomputed_value_rowids() and
            result.has_precomputed_value_rowids()):
          self.assertAllEqual(r.value_rowids(), result.value_rowids())
        if r.has_precomputed_nrows() and result.has_precomputed_nrows():
          self.assertAllEqual(r.nrows(), result.nrows())
        if (r.uniform_row_length() is not None and
            result.uniform_row_length() is not None):
          self.assertAllEqual(r.uniform_row_length(),
                              result.uniform_row_length())

  def testMergePrecomputedEncodingsFastPaths(self):
    # Same object: x gets returned as-is.
    x = RowPartition.from_row_splits([0, 3, 8, 8])
    self.assertIs(x.merge_precomputed_encodings(x), x)

    # Same encoding tensor objects: x gets returned as-is.
    y = RowPartition.from_row_splits(x.row_splits(), validate=False)
    self.assertIs(x.merge_precomputed_encodings(y), x)

  def testMergePrecomputedEncodingsWithMatchingTensors(self):
    # The encoding tensors for `a` are a superset of the encoding tensors
    # for `b`, and where they overlap, they are the same tensor objects.
    a = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4])
    b = RowPartition.from_row_splits(a.row_splits(), validate=False)
    self.assertIs(a.merge_precomputed_encodings(b), a)
    self.assertIs(b.merge_precomputed_encodings(a), a)
    self.assertIsNot(a, b)

  @parameterized.named_parameters([
      dict(
          testcase_name='RowSplitMismatch',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]),
          message='incompatible row_splits'),
      dict(
          testcase_name='RowLengthMismatch',
          x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
          y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]),
          message='incompatible row_splits'),  # row_splits is checked first
      dict(
          testcase_name='ValueRowIdMismatch',
          x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]),
          y=lambda: RowPartition.from_value_rowids([0, 3, 4]),
          message='incompatible value_rowids'),
  ])
  def testMergePrecomputedEncodingStaticErrors(self, x, y, message):
    """Shape mismatches are reported as ValueError at graph-build time."""
    # Only checked in graph mode; the test returns early when eager.
    if context.executing_eagerly():
      return
    # Errors that are caught by static shape checks.
    x = x()
    y = y()
    with self.assertRaisesRegex(ValueError, message):
      x.merge_precomputed_encodings(y).row_splits()
    with self.assertRaisesRegex(ValueError, message):
      y.merge_precomputed_encodings(x).row_splits()

  @parameterized.named_parameters([
      dict(
          testcase_name='NRowsMismatch',
          x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
          y=lambda: RowPartition.from_uniform_row_length(5, nvals=15),
          message='incompatible nrows'),
      dict(
          testcase_name='UniformRowLengthMismatch',
          x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
          y=lambda: RowPartition.from_uniform_row_length(2, nvals=8),
          message='incompatible uniform_row_length'),
      dict(
          testcase_name='RowSplitMismatch',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 5, 8]),
          message='incompatible row_splits'),
      dict(
          testcase_name='RowLengthMismatch',
          x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
          y=lambda: RowPartition.from_row_lengths([0, 0, 2]),
          message='incompatible row_splits'),  # row_splits is checked first
      dict(
          testcase_name='ValueRowIdMismatch',
          x=lambda: RowPartition.from_value_rowids([0, 3, 3]),
          y=lambda: RowPartition.from_value_rowids([0, 0, 3]),
          message='incompatible row_splits'),  # row_splits is checked first
  ])
  def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message):
    """Value mismatches raise InvalidArgumentError when evaluated."""
    # Errors that are caught by runtime value checks.
    x = x()
    y = y()
    with self.assertRaisesRegex(errors.InvalidArgumentError, message):
      self.evaluate(x.merge_precomputed_encodings(y).row_splits())
    with self.assertRaisesRegex(errors.InvalidArgumentError, message):
      self.evaluate(y.merge_precomputed_encodings(x).row_splits())
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase): #============================================================================= # RowPartition class docstring examples #============================================================================= def testClassDocStringExamples(self): # From section: "Component Tensors" rp = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]) self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8]) del rp # From section: "Alternative Row-Partitioning Schemes" rt1 = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]) rt2 = RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0]) rt3 = RowPartition.from_value_rowids( value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) rt4 = RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8) rt5 = RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8]) for rp in (rt1, rt2, rt3, rt4, rt5): self.assertAllEqual(rp.row_splits(), [0, 4, 4, 7, 8, 8]) del rt1, rt2, rt3, rt4, rt5 # From section: "Multiple Ragged Dimensions" inner_rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]) outer_rt = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5]) del inner_rt, outer_rt #============================================================================= # RowPartition Constructor (private) #============================================================================= def testRowPartitionConstruction(self): row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) rp = RowPartition(row_splits=row_splits, internal=row_partition._row_partition_factory_key) self.assertAllEqual(rp.row_splits(), [0, 2, 2, 5, 6, 7]) def testRowPartitionConstructionErrors(self): row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) with self.assertRaisesRegex(ValueError, 'RowPartition constructor is private'): RowPartition(row_splits=row_splits) with self.assertRaisesRegex( TypeError, 'Row-partitioning argument must be a Tensor'): 
RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=row_partition._row_partition_factory_key) with self.assertRaisesRegex(ValueError, r'Shape \(6, 1\) must have rank 1'): RowPartition(row_splits=array_ops.expand_dims(row_splits, 1), internal=row_partition._row_partition_factory_key) with self.assertRaisesRegex(TypeError, 'Cached value must be a Tensor or None.'): RowPartition(row_splits=row_splits, row_lengths=[2, 3, 4], internal=row_partition._row_partition_factory_key) with self.assertRaisesRegex(ValueError, 'Inconsistent dtype'): RowPartition(row_splits=constant_op.constant([0, 3], dtypes.int64), nrows=constant_op.constant(1, dtypes.int32), internal=row_partition._row_partition_factory_key) #============================================================================= # RowPartition Factory Ops #============================================================================= def testFromValueRowIdsWithDerivedNRows(self): # nrows is known at graph creation time. value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) # TODO(martinz): add nrows rp = RowPartition.from_value_rowids(value_rowids, validate=False) self.assertEqual(rp.dtype, dtypes.int64) rp_row_splits = rp.row_splits() rp_value_rowids = rp.value_rowids() rp_nrows = rp.nrows() self.assertIs(rp_value_rowids, value_rowids) # value_rowids self.assertAllEqual(rp_value_rowids, value_rowids) self.assertAllEqual(rp_nrows, 5) self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7]) def testFromValueRowIdsWithDerivedNRowsDynamic(self): # nrows is not known at graph creation time. 
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None) rp = RowPartition.from_value_rowids(value_rowids, validate=False) rp_value_rowids = rp.value_rowids() rp_nrows = rp.nrows() self.assertIs(rp_value_rowids, value_rowids) # value_rowids self.assertAllEqual(rp_value_rowids, value_rowids) self.assertAllEqual(rp_nrows, 5) def testFromValueRowIdsWithExplicitNRows(self): value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) nrows = constant_op.constant(7, dtypes.int64) rp = RowPartition.from_value_rowids(value_rowids, nrows, validate=False) rp_value_rowids = rp.value_rowids() rp_nrows = rp.nrows() rp_row_splits = rp.row_splits() self.assertIs(rp_value_rowids, value_rowids) # value_rowids self.assertIs(rp_nrows, nrows) # nrows self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7, 7, 7]) def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self): value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) nrows = constant_op.constant(5, dtypes.int64) rp = RowPartition.from_value_rowids(value_rowids, nrows, validate=False) rp_value_rowids = rp.value_rowids() rp_nrows = rp.nrows() rp_row_splits = rp.row_splits() self.assertIs(rp_value_rowids, value_rowids) # value_rowids self.assertIs(rp_nrows, nrows) # nrows self.assertAllEqual(rp_value_rowids, value_rowids) self.assertAllEqual(rp_nrows, nrows) self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7]) def testFromValueRowIdsWithEmptyValues(self): rp = RowPartition.from_value_rowids([]) rp_nrows = rp.nrows() self.assertEqual(rp.dtype, dtypes.int64) self.assertEqual(rp.value_rowids().shape.as_list(), [0]) self.assertAllEqual(rp_nrows, 0) def testFromRowSplits(self): row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) rp = RowPartition.from_row_splits(row_splits, validate=False) self.assertEqual(rp.dtype, dtypes.int64) rp_row_splits = rp.row_splits() rp_nrows = rp.nrows() 
self.assertIs(rp_row_splits, row_splits) self.assertAllEqual(rp_nrows, 5) def testFromRowSplitsWithDifferentSplitTypes(self): splits1 = [0, 2, 2, 5, 6, 7] splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64) splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32) splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32) rt1 = RowPartition.from_row_splits(splits1) rt2 = RowPartition.from_row_splits(splits2) rt3 = RowPartition.from_row_splits(splits3) rt4 = RowPartition.from_row_splits(splits4) rt5 = RowPartition.from_row_splits(splits5) self.assertEqual(rt1.row_splits().dtype, dtypes.int64) self.assertEqual(rt2.row_splits().dtype, dtypes.int64) self.assertEqual(rt3.row_splits().dtype, dtypes.int32) self.assertEqual(rt4.row_splits().dtype, dtypes.int64) self.assertEqual(rt5.row_splits().dtype, dtypes.int32) def testFromRowSplitsWithEmptySplits(self): err_msg = 'row_splits tensor may not be empty' with self.assertRaisesRegex(ValueError, err_msg): RowPartition.from_row_splits([]) def testFromRowStarts(self): nvals = constant_op.constant(7) row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64) rp = RowPartition.from_row_starts(row_starts, nvals, validate=False) self.assertEqual(rp.dtype, dtypes.int64) rp_row_starts = rp.row_starts() rp_row_splits = rp.row_splits() rp_nrows = rp.nrows() self.assertAllEqual(rp_nrows, 5) self.assertAllEqual(rp_row_starts, row_starts) self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7]) def testFromRowLimits(self): row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64) rp = RowPartition.from_row_limits(row_limits, validate=False) self.assertEqual(rp.dtype, dtypes.int64) rp_row_limits = rp.row_limits() rp_row_splits = rp.row_splits() rp_nrows = rp.nrows() self.assertAllEqual(rp_nrows, 5) self.assertAllEqual(rp_row_limits, row_limits) self.assertAllEqual(rp_row_splits, [0, 2, 2, 5, 6, 7]) def testFromRowLengths(self): row_lengths = constant_op.constant([2, 0, 
3, 1, 1], dtypes.int64) rp = RowPartition.from_row_lengths(row_lengths, validate=False) self.assertEqual(rp.dtype, dtypes.int64) rp_row_lengths = rp.row_lengths() rp_nrows = rp.nrows() self.assertIs(rp_row_lengths, row_lengths) # nrows self.assertAllEqual(rp_nrows, 5) self.assertAllEqual(rp_row_lengths, row_lengths) def testFromUniformRowLength(self): nvals = 16 a1 = RowPartition.from_uniform_row_length(nvals=nvals, uniform_row_length=2) self.assertAllEqual(a1.uniform_row_length(), 2) self.assertAllEqual(a1.nrows(), 8) def testFromUniformRowLengthWithEmptyValues(self): a = RowPartition.from_uniform_row_length(nvals=0, uniform_row_length=0, nrows=10) self.assertEqual(self.evaluate(a.nvals()), 0) self.assertEqual(self.evaluate(a.nrows()), 10) def testFromUniformRowLengthWithPlaceholders1(self): nvals = array_ops.placeholder_with_default( constant_op.constant(6, dtype=dtypes.int64), None) rt1 = RowPartition.from_uniform_row_length(nvals=nvals, uniform_row_length=3) const_nvals1 = self.evaluate(rt1.nvals()) self.assertEqual(const_nvals1, 6) def testFromUniformRowLengthWithPlaceholders2(self): nvals = array_ops.placeholder_with_default(6, None) ph_rowlen = array_ops.placeholder_with_default(3, None) rt2 = RowPartition.from_uniform_row_length( nvals=nvals, uniform_row_length=ph_rowlen) const_nvals2 = self.evaluate(rt2.nvals()) self.assertEqual(const_nvals2, 6) def testFromValueRowIdsWithBadNRows(self): value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) nrows = constant_op.constant(5, dtypes.int64) with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'): RowPartition.from_value_rowids( value_rowids=array_ops.placeholder_with_default( value_rowids, None), nrows=-2) with self.assertRaisesRegex( ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, ' r'value_rowids\[-1\]=4'): RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=2) with self.assertRaisesRegex( ValueError, r'Expected nrows >= value_rowids\[-1\] 
\+ 1; got nrows=4, ' r'value_rowids\[-1\]=4'): RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=4) with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'): RowPartition.from_value_rowids(value_rowids=array_ops.expand_dims( value_rowids, 1), nrows=nrows) with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'): RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=array_ops.expand_dims( nrows, 0)) #============================================================================= # RowPartition.__str__ #============================================================================= def testRowPartitionStr(self): row_splits = [0, 2, 5, 6, 6, 7] rp = RowPartition.from_row_splits(row_splits, validate=False) if context.executing_eagerly(): expected_repr = 'tf.RowPartition(row_splits=[0 2 5 6 6 7])' else: expected_repr = ( 'tf.RowPartition(row_splits=' 'Tensor("RowPartitionFromRowSplits/row_splits:0", ' 'shape=(6,), dtype=int64))') self.assertEqual(repr(rp), expected_repr) self.assertEqual(str(rp), expected_repr) def testRowPartitionStrUniformRowLength(self): rp = RowPartition.from_uniform_row_length(5, nvals=10, nrows=2) if context.executing_eagerly(): expected_repr = ('tf.RowPartition(nrows=2, uniform_row_length=5)') else: expected_repr = ( 'tf.RowPartition(nrows=' 'Tensor("RowPartitionFromUniformRowLength/' 'nrows:0", shape=(), dtype=int64), ' 'uniform_row_length=Tensor("RowPartitionFromUniformRowLength/' 'uniform_row_length:0", shape=(), dtype=int64))') self.assertEqual(repr(rp), expected_repr) self.assertEqual(str(rp), expected_repr) @parameterized.parameters([ # from_value_rowids { 'descr': 'bad rank for value_rowids', 'factory': RowPartition.from_value_rowids, 'value_rowids': [[1, 2], [3, 4]], 'nrows': 10 }, { 'descr': 'bad rank for nrows', 'factory': RowPartition.from_value_rowids, 'value_rowids': [1, 2, 3, 4], 'nrows': [10] }, { 'descr': 'negative value_rowid', 'factory': RowPartition.from_value_rowids, 
'value_rowids': [-5, 2, 3, 4], 'nrows': 10 }, { 'descr': 'non-monotonic-increasing value_rowid', 'factory': RowPartition.from_value_rowids, 'value_rowids': [4, 3, 2, 1], 'nrows': 10 }, { 'descr': 'value_rowid > nrows', 'factory': RowPartition.from_value_rowids, 'value_rowids': [1, 2, 3, 4], 'nrows': 2 }, # from_row_splits { 'descr': 'bad rank for row_splits', 'factory': RowPartition.from_row_splits, 'row_splits': [[1, 2], [3, 4]] }, { 'descr': 'row_splits[0] != 0', 'factory': RowPartition.from_row_splits, 'row_splits': [2, 3, 4] }, { 'descr': 'non-monotonic-increasing row_splits', 'factory': RowPartition.from_row_splits, 'row_splits': [0, 3, 2, 4] }, # from_row_lengths { 'descr': 'bad rank for row_lengths', 'factory': RowPartition.from_row_lengths, 'row_lengths': [[1, 2], [1, 0]] }, { 'descr': 'negatve row_lengths', 'factory': RowPartition.from_row_lengths, 'row_lengths': [3, -1, 2] }, # from_row_starts { 'descr': 'bad rank for row_starts', 'factory': RowPartition.from_row_starts, 'nvals': 2, 'row_starts': [[1, 2], [3, 4]] }, { 'descr': 'row_starts[0] != 0', 'factory': RowPartition.from_row_starts, 'nvals': 5, 'row_starts': [2, 3, 4] }, { 'descr': 'non-monotonic-increasing row_starts', 'factory': RowPartition.from_row_starts, 'nvals': 4, 'row_starts': [0, 3, 2, 4] }, { 'descr': 'row_starts[0] > nvals', 'factory': RowPartition.from_row_starts, 'nvals': 4, 'row_starts': [0, 2, 3, 5] }, # from_row_limits { 'descr': 'bad rank for row_limits', 'factory': RowPartition.from_row_limits, 'row_limits': [[1, 2], [3, 4]] }, { 'descr': 'row_limits[0] < 0', 'factory': RowPartition.from_row_limits, 'row_limits': [-1, 3, 4] }, { 'descr': 'non-monotonic-increasing row_limits', 'factory': RowPartition.from_row_limits, 'row_limits': [0, 3, 2, 4] }, # from_uniform_row_length { 'descr': 'rowlen * nrows != nvals (1)', 'factory': RowPartition.from_uniform_row_length, 'nvals': 5, 'uniform_row_length': 3 }, { 'descr': 'rowlen * nrows != nvals (2)', 'factory': 
RowPartition.from_uniform_row_length, 'nvals': 5, 'uniform_row_length': 6 }, { 'descr': 'rowlen * nrows != nvals (3)', 'factory': RowPartition.from_uniform_row_length, 'nvals': 6, 'uniform_row_length': 3, 'nrows': 3 }, { 'descr': 'rowlen must be a scalar', 'factory': RowPartition.from_uniform_row_length, 'nvals': 4, 'uniform_row_length': [2] }, { 'descr': 'rowlen must be nonnegative', 'factory': RowPartition.from_uniform_row_length, 'nvals': 4, 'uniform_row_length': -1 }, ]) def testFactoryValidation(self, descr, factory, **kwargs): # When input tensors have shape information, some of these errors will be # detected statically. with self.assertRaises((errors.InvalidArgumentError, ValueError)): partition = factory(**kwargs) self.evaluate(partition.row_splits()) # Remove shape information (by wrapping tensors in placeholders), and check # that we detect the errors when the graph is run. if not context.executing_eagerly(): def wrap_arg(v): return array_ops.placeholder_with_default( constant_op.constant(v, dtype=dtypes.int64), tensor_shape.TensorShape(None)) kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items()) with self.assertRaises(errors.InvalidArgumentError): partition = factory(**kwargs) self.evaluate(partition.row_splits()) @parameterized.named_parameters([ ('FromRowSplits', lambda: RowPartition.from_row_splits([0, 2, 8]), ['row_splits']), ('FromRowLengths', lambda: RowPartition.from_row_lengths([3, 0, 8]), ['row_splits', 'row_lengths']), ('FromValueRowIds', lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]), ['row_splits', 'value_rowids', 'row_lengths', 'nrows']), ('FromRowStarts', lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10), ['row_splits']), ('FromRowLimits', lambda: RowPartition.from_row_limits([3, 7, 10]), ['row_splits']), ]) def testPrecomputedSplits(self, rp_factory, expected_encodings): rp = rp_factory() self.assertEqual(rp._has_precomputed_row_splits(), 'row_splits' in expected_encodings) 
self.assertEqual(rp._has_precomputed_row_lengths(), 'row_lengths' in expected_encodings) self.assertEqual(rp._has_precomputed_value_rowids(), 'value_rowids' in expected_encodings) self.assertEqual(rp._has_precomputed_nrows(), 'nrows' in expected_encodings) def testWithPrecomputedSplits(self): rp = RowPartition.from_row_splits([0, 2, 8]) rp_with_row_splits = rp._with_precomputed_row_splits() self.assertTrue(rp_with_row_splits._has_precomputed_row_splits()) self.assertFalse(rp._has_precomputed_row_lengths()) rp_with_row_lengths = rp._with_precomputed_row_lengths() self.assertTrue(rp_with_row_lengths._has_precomputed_row_lengths()) self.assertFalse(rp._has_precomputed_value_rowids()) rp_with_value_rowids = rp._with_precomputed_value_rowids() self.assertTrue(rp_with_value_rowids._has_precomputed_value_rowids()) self.assertFalse(rp._has_precomputed_nrows()) rp_with_nrows = rp._with_precomputed_nrows() self.assertTrue(rp_with_nrows._has_precomputed_nrows()) self.assertFalse(rp._has_precomputed_nvals()) rp_with_nvals = rp._with_precomputed_nvals() self.assertTrue(rp_with_nvals._has_precomputed_nvals()) @parameterized.named_parameters([ dict(testcase_name='FromRowSplitsAndRowSplits', x=lambda: RowPartition.from_row_splits([0, 3, 8]), y=lambda: RowPartition.from_row_splits([0, 3, 8]), expected_encodings=['row_splits']), dict(testcase_name='FromRowSplitsAndUniformRowLength', x=lambda: RowPartition.from_row_splits([0, 3, 6]), y=lambda: RowPartition.from_uniform_row_length(3, nvals=6), expected_encodings=['row_splits', 'uniform_row_length', 'nrows']), dict(testcase_name='FromRowSplitsAndRowLengths', x=lambda: RowPartition.from_row_splits([0, 3, 8]), y=lambda: RowPartition.from_row_lengths([3, 5]), expected_encodings=['row_splits', 'row_lengths']), dict( testcase_name='FromRowSplitsAndValueRowIds', x=lambda: RowPartition.from_row_splits([0, 3, 8]), y=lambda: RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1]), expected_encodings=[ 'row_splits', 'row_lengths', 
'value_rowids', 'nrows' ]), dict(testcase_name='FromRowSplitsAndRowSplitsPlusNRows', x=lambda: RowPartition.from_row_splits([0, 3, 8]), y=lambda: RowPartition.from_row_splits([0, 3, 8]). _with_precomputed_nrows(), expected_encodings=['row_splits', 'nrows']), ]) def testMergePrecomputedEncodings(self, x, y, expected_encodings): x = x() y = y() for validate in (True, False): result = x._merge_precomputed_encodings(y, validate) self.assertEqual(result._has_precomputed_row_splits(), 'row_splits' in expected_encodings) self.assertEqual(result._has_precomputed_row_lengths(), 'row_lengths' in expected_encodings) self.assertEqual(result._has_precomputed_value_rowids(), 'value_rowids' in expected_encodings) self.assertEqual(result._has_precomputed_nrows(), 'nrows' in expected_encodings) self.assertEqual(result.uniform_row_length() is not None, 'uniform_row_length' in expected_encodings) for r in (x, y): if (r._has_precomputed_row_splits() and result._has_precomputed_row_splits()): self.assertAllEqual(r.row_splits(), result.row_splits()) if (r._has_precomputed_row_lengths() and result._has_precomputed_row_lengths()): self.assertAllEqual(r.row_lengths(), result.row_lengths()) if (r._has_precomputed_value_rowids() and result._has_precomputed_value_rowids()): self.assertAllEqual(r.value_rowids(), result.value_rowids()) if r._has_precomputed_nrows( ) and result._has_precomputed_nrows(): self.assertAllEqual(r.nrows(), result.nrows()) if (r.uniform_row_length() is not None and result.uniform_row_length() is not None): self.assertAllEqual(r.uniform_row_length(), result.uniform_row_length()) def testMergePrecomputedEncodingsFastPaths(self): # Same object: x gets returned as-is. x = RowPartition.from_row_splits([0, 3, 8, 8]) self.assertIs(x._merge_precomputed_encodings(x), x) # Same encoding tensor objects: x gets returned as-is. 
y = RowPartition.from_row_splits(x.row_splits(), validate=False) self.assertIs(x._merge_precomputed_encodings(y), x) def testMergePrecomputedEncodingsWithMatchingTensors(self): # The encoding tensors for `a` are a superset of the encoding tensors # for `b`, and where they overlap, they the same tensor objects. a = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]) b = RowPartition.from_row_splits(a.row_splits(), validate=False) self.assertIs(a._merge_precomputed_encodings(b), a) self.assertIs(b._merge_precomputed_encodings(a), a) self.assertIsNot(a, b) @parameterized.named_parameters([ dict(testcase_name='RowSplitMismatch', x=lambda: RowPartition.from_row_splits([0, 3, 8]), y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]), message='incompatible row_splits'), dict(testcase_name='RowLengthMismatch', x=lambda: RowPartition.from_row_lengths([2, 0, 2]), y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]), message='incompatible row_splits'), # row_splits is checked first dict(testcase_name='ValueRowIdMismatch', x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]), y=lambda: RowPartition.from_value_rowids([0, 3, 4]), message='incompatible value_rowids'), ]) def testMergePrecomputedEncodingStaticErrors(self, x, y, message): if context.executing_eagerly(): return # Errors that are caught by static shape checks. 
x = x() y = y() with self.assertRaisesRegex(ValueError, message): x._merge_precomputed_encodings(y).row_splits() with self.assertRaisesRegex(ValueError, message): y._merge_precomputed_encodings(x).row_splits() @parameterized.named_parameters([ dict(testcase_name='NRowsMismatchAlt', x=lambda: RowPartition.from_uniform_row_length( 5, nrows=4, nvals=20), y=lambda: RowPartition.from_uniform_row_length( 5, nrows=3, nvals=15), message='incompatible nrows'), dict(testcase_name='UniformRowLengthMismatch', x=lambda: RowPartition.from_uniform_row_length(5, nvals=20), y=lambda: RowPartition.from_uniform_row_length(2, nvals=8), message='incompatible (nvals|uniform_row_length)'), dict(testcase_name='RowSplitMismatch', x=lambda: RowPartition.from_row_splits([0, 3, 8]), y=lambda: RowPartition.from_row_splits([0, 5, 8]), message='incompatible row_splits'), dict(testcase_name='RowLengthMismatch', x=lambda: RowPartition.from_row_lengths([2, 0, 2]), y=lambda: RowPartition.from_row_lengths([0, 0, 2]), message='incompatible (row_splits|nvals)'), dict(testcase_name='ValueRowIdMismatch', x=lambda: RowPartition.from_value_rowids([0, 3, 3]), y=lambda: RowPartition.from_value_rowids([0, 0, 3]), message='incompatible row_splits'), # row_splits is checked first ]) def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message): # Errors that are caught by runtime value checks. x = x() y = y() with self.assertRaisesRegex(errors.InvalidArgumentError, message): self.evaluate(x._merge_precomputed_encodings(y).row_splits()) with self.assertRaisesRegex(errors.InvalidArgumentError, message): self.evaluate(y._merge_precomputed_encodings(x).row_splits()) @parameterized.named_parameters([ # It throws the right error, but it still complains. 
dict(testcase_name='NRowsMismatch', x=lambda: RowPartition.from_uniform_row_length(5, nvals=20), y=lambda: RowPartition.from_uniform_row_length(5, nvals=15), message='incompatible nvals', emessage='incompatible nrows'), ]) def testMergePrecomputedEncodingStaticErrors2(self, x, y, message, emessage): # Message error and type varies depending upon eager execution. x = x() y = y() error_type = errors_impl.InvalidArgumentError expected_message = emessage if context.executing_eagerly() else message with self.assertRaisesRegex(error_type, expected_message): self.evaluate(x._merge_precomputed_encodings(y).row_splits()) with self.assertRaisesRegex(error_type, expected_message): self.evaluate(y._merge_precomputed_encodings(x).row_splits()) @parameterized.named_parameters([ dict(testcase_name='from_uniform_row_length', x=lambda: RowPartition.from_uniform_row_length(5, nvals=20), expected=True), dict(testcase_name='from_row_splits', x=lambda: RowPartition.from_row_splits([0, 3, 8]), expected=False), dict(testcase_name='from_row_lengths', x=lambda: RowPartition.from_row_lengths([2, 0, 2]), expected=False), dict(testcase_name='from_row_lengths_uniform', x=lambda: RowPartition.from_row_lengths([3, 3, 3]), expected=False), ]) def testIsUniform(self, x, expected): x = x() self.assertEqual(expected, x.is_uniform()) @parameterized.named_parameters([ dict(testcase_name='doc_example', x=lambda: RowPartition.from_row_lengths([3, 2, 0, 2]), expected=[0, 1, 2, 0, 1, 0, 1]), dict(testcase_name='from_uniform_row_length', x=lambda: RowPartition.from_uniform_row_length(4, nvals=12), expected=[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), dict(testcase_name='from_row_splits', x=lambda: RowPartition.from_row_splits([0, 3, 8]), expected=[0, 1, 2, 0, 1, 2, 3, 4]), ]) def testOffsetsInRows(self, x, expected): x = x() actual = x.offsets_in_rows() self.assertAllEqual(expected, actual) def testFromUniformRowLengthBugConvertToTensor(self): # This originally failed to run because nrows was dtypes.int32. 
I think # we may need to consider the semantics of the type of a RowPartition # if preferred_dtype is unspecified. Also, looking at convert_to_tensor: # dtype specifies the type of the output. # preferred_dtype/dtype_hint is a suggestion, and dtype_hint is the new # name. nrows = constant_op.constant(3, dtype=dtypes.int32) nvals = constant_op.constant(12, dtype=dtypes.int64) row_length = constant_op.constant(4, dtype=dtypes.int64) rp = RowPartition.from_uniform_row_length(row_length, nvals=nvals, nrows=nrows, dtype=dtypes.int64) self.assertEqual(rp.nrows().dtype, dtypes.int64) def testFromUniformRowLengthNvalDynamic(self): # A key question is whether if nrows and uniform_row_length are known, # and nvals is given but not known statically, should we determine nvals? # TODO(martinz): Uncomment after nvals is fixed. # @def_function.function( # input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) # def foo(nvals): # rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3) # nval_output = tensor_util.constant_value(rp.nvals()) # self.assertEqual(nval_output, 36) # foo(constant_op.constant(36, dtype=dtypes.int32)) pass def testFromUniformRowLengthNvalDynamicNoValidate(self): # A key question is whether if nrows and uniform_row_length are known, # and nvals is given but not known statically, should we determine nvals? # TODO(martinz): Uncomment after nvals is fixed. # @def_function.function( # input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) # def foo(nvals): # rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3, # validate=False) # nval_output = tensor_util.constant_value(rp.nvals()) # self.assertEqual(nval_output, 36) # foo(constant_op.constant(36, dtype=dtypes.int32)) pass def testFromUniformRowLengthNvalDynamicWrong(self): # A key question is whether if nrows and uniform_row_length are known, # and nvals is given but not known statically and WRONG, # what should we do? 
We add a check, but checks are only checked for # row_splits. @def_function.function( input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) def foo(nvals): rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3) return rp.nvals() with self.assertRaises(errors.InvalidArgumentError): nvals = foo(constant_op.constant(7, dtype=dtypes.int32)) self.evaluate(nvals) def testFromUniformRowLengthNvalDynamicWrongRowSplits(self): # A key question is whether if nrows and uniform_row_length are known, # and nvals is given but not known statically and WRONG, # what should we do? # A key question is whether if nrows and uniform_row_length are known, # and nvals is given but not known statically and WRONG, # what should we do? We add a check, but checks are only checked for # row_splits. @def_function.function( input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) def foo(nvals): rp = RowPartition.from_uniform_row_length(12, nvals=nvals, nrows=3) return rp.row_splits() with self.assertRaises(errors.InvalidArgumentError): rs = foo(constant_op.constant(7, dtype=dtypes.int32)) self.evaluate(rs) def testFromUniformRowPartitionNrows(self): rp = RowPartition.from_uniform_row_length(3, nrows=4) self.assertAllEqual(4, rp.nrows()) self.assertAllEqual(3, rp.uniform_row_length()) self.assertAllEqual(12, rp.static_nvals) def testFromUniformRowPartitionNvalsStatic(self): rp = RowPartition.from_uniform_row_length(3, nvals=12) self.assertAllEqual(4, rp.static_nrows) self.assertAllEqual(3, rp.static_uniform_row_length) self.assertAllEqual(12, rp.static_nvals) def testFromUniformRowPartitionNvalsStaticNoValidate(self): rp = RowPartition.from_uniform_row_length(3, nrows=4, nvals=12, validate=False) self.assertAllEqual(4, rp.static_nrows) self.assertAllEqual(3, rp.static_uniform_row_length) self.assertAllEqual(12, rp.static_nvals) def testFromUniformRowPartitionNvalsIs(self): # TODO(martinz): Uncomment after nvals is fixed. 
# nvals = constant_op.constant(12) # rp = RowPartition.from_uniform_row_length(3, nvals=nvals) # self.assertIs(rp.nvals(), nvals) pass def testFromUniformRowPartitionRowStartsStatic(self): rp = RowPartition.from_row_starts([0, 3, 6], nvals=12) self.assertAllEqual(12, rp.static_nvals) def testStaticNrows(self): rp = RowPartition.from_row_splits([0, 3, 4, 5]) static_nrows = rp.static_nrows self.assertIsInstance(static_nrows, int) self.assertAllEqual(3, static_nrows) def testStaticNrowsUnknown(self): @def_function.function( input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) def foo(rs): rp = RowPartition.from_row_splits(rs) static_nrows = rp.static_nrows self.assertIsNone(static_nrows) foo(array_ops.constant([0, 3, 4, 5], dtype=dtypes.int32))