def testSplitError(self, descr, pylist, row_lengths, num_or_size_splits,
                   exception, message, axis=0, num=None):
    """Checks that an invalid ragged split raises the expected error.

    Builds a RaggedTensor from `pylist` / `row_lengths`, splits it with
    `num_or_size_splits`, `axis` and `num`, and requires that `exception` is
    raised. When `message` is not None, it is treated as a regular expression
    that must be found in the error text (via `assertRaisesRegex`). Before this
    change, `message` was accepted but never checked. When `message` is None,
    only the exception type is checked.

    `descr` is not read here; presumably it only labels the parameterized case.
    """
    rt = ragged_tensor.RaggedTensor.from_row_lengths(pylist, row_lengths)
    if message is None:
        raises_ctx = self.assertRaises(exception)
    else:
        raises_ctx = self.assertRaisesRegex(exception, message)
    # Both the split call and its evaluation stay inside the context, so the
    # error is caught whichever of the two steps raises it.
    with raises_ctx:
        result = ragged_array_ops.split(rt, num_or_size_splits, axis, num)
        self.evaluate(result)
def testSplitTensorDtype(self, dtype):
    """Splits a ragged tensor using a 1-D tensor of sizes with the given dtype.

    The input has rows [1.0, 2.0, 3.0] and [4.0]. Splitting with sizes [1, 1]
    (no explicit axis) is expected to produce one single-row piece per row.
    """
    flat_values = [1.0, 2.0, 3.0, 4.0]
    rt = ragged_tensor.RaggedTensor.from_row_lengths(flat_values, [3, 1])
    # The split sizes are supplied as a tensor, not as a Python list.
    size_tensor = ops.convert_to_tensor([1, 1], dtype=dtype)
    pieces = ragged_array_ops.split(rt, size_tensor)
    want = [
        ragged_tensor.RaggedTensor.from_row_lengths([1.0, 2.0, 3.0], [3]),
        ragged_tensor.RaggedTensor.from_row_lengths([4.0], [1]),
    ]
    self.assertLen(pieces, len(want))
    for actual, wanted in zip(pieces, want):
        self.assertAllEqual(actual, wanted)
def testSplit(self, descr, pylist, row_lengths, num_or_size_splits, expected,
              axis=0, num=None, name=None):
    """Checks that a ragged split produces exactly the `expected` pieces.

    The RaggedTensor built from `pylist` / `row_lengths` is split with
    `num_or_size_splits`, `axis`, `num` and `name`. The result must have the
    same number of pieces as `expected`, and each piece must equal its
    counterpart.

    `descr` is not read here; presumably it only labels the parameterized case.
    """
    source = ragged_tensor.RaggedTensor.from_row_lengths(pylist, row_lengths)
    split_args = (num_or_size_splits, axis, num, name)
    pieces = ragged_array_ops.split(source, *split_args)
    self.assertLen(pieces, len(expected))
    for actual, wanted in zip(pieces, expected):
        self.assertAllEqual(actual, wanted)
def split_tensors(rt, split_lengths):
    """Splits `rt` by `split_lengths`, forwarding `axis` and `num` as keywords.

    `axis` and `num` are free variables here. They are presumably captured
    from the enclosing test function; confirm against the surrounding code.
    """
    pieces = ragged_array_ops.split(rt, split_lengths, axis=axis, num=num)
    return pieces
def split_tensors(rt):
    """Splits `rt` by passing the integer 2 as the split argument.

    Presumably this requests two equal pieces; confirm against
    `ragged_array_ops.split`.
    """
    split_count = 2
    return ragged_array_ops.split(rt, split_count)