def test2DRaggedTensorWithOneRaggedDimension(self):
  """bounding_shape of a 2-D ragged tensor is [num_rows, longest_row_len]."""
  values = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
  # Each case: (row_splits, expected bounding shape).
  cases = [
      ([0, 2, 5, 6, 6, 7], [5, 3]),  # row lengths 2, 3, 1, 0, 1
      ([0, 7], [1, 7]),  # a single row holding every value
      ([0, 0, 7, 7], [3, 7]),  # empty rows around one full row
  ]
  for row_splits, expected_shape in cases:
    rt = ragged.from_row_splits(values, row_splits)
    actual_shape = self.evaluate(ragged.bounding_shape(rt))
    self.assertEqual(actual_shape.tolist(), expected_shape)
def test3DRaggedTensorWithOneRaggedDimension(self):
  """bounding_shape when each ragged entry is itself a length-2 vector.

  The expected shape is [num_rows, longest_row_len, 2]; the trailing 2 is the
  width of every entry in `values`.
  """
  values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
  rt1 = ragged.from_row_splits(values, [0, 2, 5, 6, 6, 7])
  rt2 = ragged.from_row_splits(values, [0, 7])
  rt3 = ragged.from_row_splits(values, [0, 0, 7, 7])
  # Evaluate via self.evaluate rather than the deprecated
  # test_session() + Tensor.eval() pattern.
  self.assertEqual(
      self.evaluate(ragged.bounding_shape(rt1)).tolist(), [5, 3, 2])
  self.assertEqual(
      self.evaluate(ragged.bounding_shape(rt2)).tolist(), [1, 7, 2])
  self.assertEqual(
      self.evaluate(ragged.bounding_shape(rt3)).tolist(), [3, 7, 2])
def test3DRaggedTensorWithOneRaggedDimension(self):
  """bounding_shape when each ragged entry is itself a length-2 vector.

  The expected shape is [num_rows, longest_row_len, 2]; the trailing 2 is the
  width of every entry in `values`.
  """
  values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
  expectations = (
      ([0, 2, 5, 6, 6, 7], [5, 3, 2]),
      ([0, 7], [1, 7, 2]),
      ([0, 0, 7, 7], [3, 7, 2]),
  )
  for splits, want in expectations:
    got = self.evaluate(
        ragged.bounding_shape(ragged.from_row_splits(values, splits)))
    self.assertEqual(got.tolist(), want)
def testRaggedTensorSplitsMismatchErrorAtRuntime(self):
  """map_inner_values rejects operands whose row splits differ.

  The splits are routed through placeholder_with_default with shape None, so
  (as the test name indicates) the mismatch is expected to surface only when
  the result is evaluated, not while the op is being built.
  """
  splits1 = array_ops.placeholder_with_default(
      constant_op.constant([0, 3, 3, 5], dtypes.int64), None)
  splits2 = array_ops.placeholder_with_default(
      constant_op.constant([0, 1, 3, 5], dtypes.int64), None)
  x = ragged.from_row_splits([3, 1, 4, 1, 5], splits1)
  y = ragged.from_row_splits([1, 2, 3, 4, 5], splits2)
  result = ragged.map_inner_values(math_ops.add, x, y)
  # self.evaluate replaces the deprecated test_session() + result.eval pattern.
  self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      r'\[Inputs must have identical ragged splits\] '
      r'\[Condition x == y did not hold element-wise:\].*',
      self.evaluate, result)
def testShapeMismatchError2(self):
  """segment_sum rejects segment_ids whose ragged shape differs from data's.

  Row 1 of `rt` is empty while row 1 of `segment_ids` has one element, so
  segment_ids.shape is not a prefix of data.shape.
  """
  rt = ragged.constant([
      [[111, 112, 113, 114], [121]],  # row 0
      [],                             # row 1
      [[], [321, 322], [331]],        # row 2
      [[411, 412]]                    # row 3
  ])  # pyformat: disable
  segment_ids = ragged.constant([[1, 2], [1], [1, 1, 2], [2]])
  expected_error = 'segment_ids.shape must be a prefix of data.shape.*'

  # Error is raised at graph-building time if we can detect it then.
  self.assertRaisesRegexp(errors.InvalidArgumentError, expected_error,
                          ragged.segment_sum, rt, segment_ids, 3)

  # Otherwise, error is raised when we run the graph.
  segment_ids2 = ragged.from_row_splits(
      array_ops.placeholder_with_default(segment_ids.values, None),
      array_ops.placeholder_with_default(segment_ids.row_splits, None))
  segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
  # self.evaluate replaces the deprecated test_session() + segmented2.eval.
  self.assertRaisesRegexp(errors.InvalidArgumentError, expected_error,
                          self.evaluate, segmented2)
def testUnknownRankError(self):
  """ragged.add raises ValueError when an operand's values have unknown rank."""
  x = ragged.constant([[1, 2], [3]])
  # shape=None hides the static shape (and therefore the rank) of the values.
  hidden_rank_values = array_ops.placeholder_with_default(
      [1, 2, 3], shape=None)
  y = ragged.from_row_splits(hidden_rank_values, x.row_splits)
  message_regex = r'Unable to broadcast: unknown rank'
  with self.assertRaisesRegexp(ValueError, message_regex):
    ragged.add(x, y)
def testExplicitAxisOptimizations(self):
  """bounding_shape with an explicit axis returns only the requested dims.

  `axis` may be a single int (scalar result) or a list of ints (one entry per
  requested axis, in the requested order).
  """
  rt = ragged.from_row_splits(b'a b c d e f g'.split(), [0, 2, 5, 6, 6, 7])
  # self.evaluate replaces the deprecated test_session() + Tensor.eval().
  self.assertEqual(self.evaluate(ragged.bounding_shape(rt, 0)).tolist(), 5)
  self.assertEqual(self.evaluate(ragged.bounding_shape(rt, 1)).tolist(), 3)
  self.assertEqual(
      self.evaluate(ragged.bounding_shape(rt, [1, 0])).tolist(), [3, 5])
def testUnknownRankError(self):
  """ragged.add raises ValueError when an operand's values have unknown rank."""
  x = ragged.constant([[1, 2], [3]])
  # shape=None hides the static shape (and therefore the rank) of the values.
  hidden_rank_values = array_ops.placeholder_with_default(
      [1, 2, 3], shape=None)
  y = ragged.from_row_splits(hidden_rank_values, x.row_splits)
  message_regex = (r'Ragged elementwise ops require that rank \(number '
                   r'of dimensions\) be statically known.')
  with self.assertRaisesRegexp(ValueError, message_regex):
    ragged.add(x, y)
def testRaggedBatchGatherUnknownRankError(self):
  """batch_gather raises ValueError for indices whose shape is unknown.

  Checked both for a plain placeholder and for a ragged tensor built on it.
  """
  params = [['a', 'b'], ['c', 'd']]
  indices = array_ops.placeholder(dtypes.int32, shape=None)
  ragged_indices = ragged.from_row_splits(indices, [0, 2, 4])
  for candidate in (indices, ragged_indices):
    with self.assertRaisesRegexp(
        ValueError, 'batch_gather does not allow indices with unknown shape.'):
      ragged.batch_gather(params, candidate)
def testKernelErrors(self):
  """to_sparse reports malformed row splits as InvalidArgumentError at run time.

  Covers a first split that is not 0, a final split that does not match the
  number of values (at the outer and inner ragged levels), and empty splits.
  """
  # An empty vector, defined using a placeholder to ensure that we can't
  # determine that it's invalid at graph-construction time.
  empty_vector = array_ops.placeholder_with_default(
      array_ops.zeros([0], dtypes.int64), shape=None)

  def assert_to_sparse_fails(rt, error_regex):
    # Builds the to_sparse op, then expects evaluating it to fail. Uses
    # self.evaluate instead of the deprecated test_session() + Tensor.eval.
    self.assertRaisesRegexp(errors.InvalidArgumentError, error_regex,
                            self.evaluate, ragged.to_sparse(rt))

  bad_rt1 = ragged.from_row_splits(row_splits=[2, 3], values=[1, 2, 3])
  bad_split0_error = r'First value of ragged splits must be 0.*'
  assert_to_sparse_fails(bad_rt1, bad_split0_error)

  bad_rt2 = ragged.from_row_splits(row_splits=[0, 5], values=empty_vector)
  bad_rt3 = ragged.from_row_splits(
      row_splits=[0, 1],
      values=ragged.from_row_splits(row_splits=[0, 5], values=empty_vector))
  split_mismatch1_error = r'Final value of ragged splits must match.*'
  for rt in [bad_rt2, bad_rt3]:
    assert_to_sparse_fails(rt, split_mismatch1_error)

  bad_rt4 = ragged.from_row_splits(
      row_splits=[0, 5],
      values=ragged.from_row_splits(row_splits=[0], values=empty_vector))
  split_mismatch2_error = r'Final value of ragged splits must match.*'
  assert_to_sparse_fails(bad_rt4, split_mismatch2_error)

  bad_rt5 = ragged.from_row_splits(row_splits=empty_vector, values=[])
  empty_splits_error = r'ragged splits may not be empty.*'
  assert_to_sparse_fails(bad_rt5, empty_splits_error)
def testKernelErrors(self):
  """to_sparse reports malformed row splits as InvalidArgumentError at run time.

  Covers a first split that is not 0, a final split that does not match the
  number of values (at the outer and inner ragged levels), and empty splits.
  """
  # An empty vector, defined using a placeholder to ensure that we can't
  # determine that it's invalid at graph-construction time.
  empty_vector = array_ops.placeholder_with_default(
      array_ops.zeros([0], dtypes.int64), shape=None)

  def assert_to_sparse_fails(rt, error_regex):
    # Builds the to_sparse op, then expects evaluating it to fail. Uses
    # self.evaluate instead of the deprecated test_session() + Tensor.eval.
    self.assertRaisesRegexp(errors.InvalidArgumentError, error_regex,
                            self.evaluate, ragged.to_sparse(rt))

  bad_rt1 = ragged.from_row_splits(row_splits=[2, 3], values=[1, 2, 3])
  bad_split0_error = r'First value of ragged splits must be 0.*'
  assert_to_sparse_fails(bad_rt1, bad_split0_error)

  bad_rt2 = ragged.from_row_splits(row_splits=[0, 5], values=empty_vector)
  bad_rt3 = ragged.from_row_splits(
      row_splits=[0, 1],
      values=ragged.from_row_splits(row_splits=[0, 5], values=empty_vector))
  split_mismatch1_error = r'Final value of ragged splits must match.*'
  for rt in [bad_rt2, bad_rt3]:
    assert_to_sparse_fails(rt, split_mismatch1_error)

  bad_rt4 = ragged.from_row_splits(
      row_splits=[0, 5],
      values=ragged.from_row_splits(row_splits=[0], values=empty_vector))
  split_mismatch2_error = r'Final value of ragged splits must match.*'
  assert_to_sparse_fails(bad_rt4, split_mismatch2_error)

  bad_rt5 = ragged.from_row_splits(row_splits=empty_vector, values=[])
  empty_splits_error = r'ragged splits may not be empty.*'
  assert_to_sparse_fails(bad_rt5, empty_splits_error)
def testExplicitAxisOptimizations(self):
  """bounding_shape with an explicit axis returns only the requested dims.

  `axis` may be a single int (scalar result) or a list of ints (one entry per
  requested axis, in the requested order).
  """
  rt = ragged.from_row_splits(b'a b c d e f g'.split(), [0, 2, 5, 6, 6, 7])
  for axis, expected in [(0, 5), (1, 3), ([1, 0], [3, 5])]:
    observed = self.evaluate(ragged.bounding_shape(rt, axis))
    self.assertEqual(observed.tolist(), expected)