def testRaggedRankTwo(self):
  """Checks segment_sum over a rank-2 ragged tensor with flat segment ids.

  Rows sharing a segment id are added position by position; where inner rows
  have different lengths, the extra values pass through unchanged (e.g.
  ``[111 + 411, 112 + 412, 113, 114]``). Segments that receive no rows, or
  only empty rows, come out as ``[]``.

  NOTE(review): a second ``testRaggedRankTwo`` exists later in this chunk; if
  both are defined in the same class, the later definition shadows this one
  and this test never runs -- confirm and rename one of them.
  """
  rt = ragged.constant([
      [[111, 112, 113, 114], [121]],  # row 0
      [],                             # row 1
      [[], [321, 322], [331]],        # row 2
      [[411, 412]],                   # row 3
  ])  # pyformat: disable

  cases = [
      # Row 0 alone forms segment 0; rows 1-3 are summed into segment 2.
      ([0, 2, 2, 2],
       [[[111, 112, 113, 114], [121]],
        [],
        [[411, 412], [321, 322], [331]]]),
      # Rows 0, 2 and 3 are summed into segment 1; segment 2 receives only
      # the empty row 1, and segment 0 receives nothing.
      ([1, 2, 1, 1],
       [[],
        [[111 + 411, 112 + 412, 113, 114], [121 + 321, 322], [331]],
        []]),
  ]  # pyformat: disable

  for ids, want in cases:
    got = ragged.segment_sum(rt, ids, 3)
    self.assertRaggedEqual(got, want)
def testShapeMismatchError2(self):
  """segment_sum must reject ragged segment_ids whose shape mismatches data.

  Row 1 of ``rt`` is empty, but row 1 of ``segment_ids`` holds one id, so
  ``segment_ids.shape`` is not a prefix of ``data.shape``. The test covers two
  paths. In the first, the row partitions are statically known and the error
  is expected while building the op. In the second, the partitions are fed
  through ``placeholder_with_default`` with no static shape and the error is
  expected only when the graph runs.
  """
  rt = ragged.constant([
      [[111, 112, 113, 114], [121]],  # row 0
      [],                             # row 1
      [[], [321, 322], [331]],        # row 2
      [[411, 412]],                   # row 3
  ])  # pyformat: disable
  segment_ids = ragged.constant([[1, 2], [1], [1, 1, 2], [2]])

  # Error is raised at graph-building time if we can detect it then.
  # (assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
  # which was removed from unittest in Python 3.12.)
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      'segment_ids.shape must be a prefix of data.shape.*'):
    ragged.segment_sum(rt, segment_ids, 3)

  # Otherwise, error is raised when we run the graph. Passing shape=None to
  # placeholder_with_default leaves the static shape of the values/row_splits
  # unknown, so the mismatch cannot be caught while building the op.
  segment_ids2 = ragged.from_row_splits(
      array_ops.placeholder_with_default(segment_ids.values, None),
      array_ops.placeholder_with_default(segment_ids.row_splits, None))
  segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
  with self.test_session():
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'segment_ids.shape must be a prefix of data.shape.*'):
      segmented2.eval()
def testRaggedRankTwo(self):
  """Evaluates segment_sum on a rank-2 ragged tensor and compares as lists.

  Each case maps the four outer rows of ``rt`` to one of three segments.
  Rows in the same segment are added position by position, and inner rows of
  unequal length keep their trailing values. The result is evaluated and
  converted with ``tolist()`` before comparison.

  NOTE(review): another ``testRaggedRankTwo`` appears earlier in this chunk;
  if both are in the same class, this definition shadows that one -- confirm.
  """
  rt = ragged.constant([
      [[111, 112, 113, 114], [121]],  # row 0
      [],                             # row 1
      [[], [321, 322], [331]],        # row 2
      [[411, 412]],                   # row 3
  ])  # pyformat: disable

  def check(ids, want):
    # Runs segment_sum with 3 output segments and compares the evaluated value.
    actual = ragged.segment_sum(rt, ids, 3)
    self.assertEqual(self.evaluate(actual).tolist(), want)

  # Row 0 -> segment 0; rows 1, 2, 3 -> segment 2; segment 1 stays empty.
  check([0, 2, 2, 2],
        [[[111, 112, 113, 114], [121]],
         [],
         [[411, 412], [321, 322], [331]]])  # pyformat: disable

  # Rows 0, 2, 3 -> segment 1; only the empty row 1 -> segment 2.
  check([1, 2, 1, 1],
        [[],
         [[111 + 411, 112 + 412, 113, 114], [121 + 321, 322], [331]],
         []])  # pyformat: disable
def testRaggedSegmentIds(self):
  """Checks segment_sum when segment_ids is itself a ragged tensor.

  ``segment_ids`` has one id per inner row of ``rt`` (outer row lengths 2, 0,
  3, 1). Each inner row is therefore routed to a segment on its own, and rows
  sharing an id are summed position by position.
  """
  rt = ragged.constant([
      [[111, 112, 113, 114], [121]],  # row 0
      [],                             # row 1
      [[], [321, 322], [331]],        # row 2
      [[411, 412]],                   # row 3
  ])  # pyformat: disable
  segment_ids = ragged.constant([[1, 2], [], [1, 1, 2], [2]])

  result = ragged.segment_sum(rt, segment_ids, 3)

  want = [
      [],                                 # no inner row has id 0
      [111 + 321, 112 + 322, 113, 114],   # rt[0][0] + rt[2][0] + rt[2][1]
      [121 + 331 + 411, 412],             # rt[0][1] + rt[2][2] + rt[3][0]
  ]  # pyformat: disable
  self.assertEqual(self.evaluate(result).tolist(), want)