コード例 #1
0
 def testSplitError(self,
                    descr,
                    pylist,
                    row_lengths,
                    num_or_size_splits,
                    exception,
                    message,
                    axis=0,
                    num=None):
     """Checks that an invalid ragged split raises the expected error.

     Builds a RaggedTensor from `pylist` and `row_lengths`, then asserts that
     `ragged_array_ops.split(rt, num_or_size_splits, axis, num)` raises
     `exception` with an error message matching the regex `message`.
     `descr` is not used in the body (presumably it labels the
     parameterized case).
     """
     rt = ragged_tensor.RaggedTensor.from_row_lengths(pylist, row_lengths)
     # Check the error text as well as its type. With a plain assertRaises,
     # `message` was silently ignored. An unrelated failure of the same
     # exception type would then still pass the test.
     with self.assertRaisesRegex(exception, message):
         result = ragged_array_ops.split(rt, num_or_size_splits, axis, num)
         # The error may only surface on evaluation (presumably in graph
         # mode), so evaluation stays inside the context manager.
         self.evaluate(result)
コード例 #2
0
 def testSplitTensorDtype(self, dtype):
     """Splits a two-row ragged tensor using a 1-D size tensor of `dtype`."""
     source = ragged_tensor.RaggedTensor.from_row_lengths(
         [1.0, 2.0, 3.0, 4.0], [3, 1])
     # The split sizes are given as a 1-D tensor: one row per piece.
     sizes = ops.convert_to_tensor([1, 1], dtype=dtype)
     pieces = ragged_array_ops.split(source, sizes)
     want = [
         ragged_tensor.RaggedTensor.from_row_lengths([1.0, 2.0, 3.0], [3]),
         ragged_tensor.RaggedTensor.from_row_lengths([4.0], [1]),
     ]
     self.assertLen(pieces, len(want))
     for got, exp in zip(pieces, want):
         self.assertAllEqual(got, exp)
コード例 #3
0
 def testSplit(self,
               descr,
               pylist,
               row_lengths,
               num_or_size_splits,
               expected,
               axis=0,
               num=None,
               name=None):
     """Compares `ragged_array_ops.split` output with the `expected` pieces.

     The input RaggedTensor is built from `pylist` and `row_lengths`.
     `num_or_size_splits`, `axis`, `num` and `name` are forwarded positionally
     to the split op. `descr` is unused here (presumably a case label).
     """
     source = ragged_tensor.RaggedTensor.from_row_lengths(pylist, row_lengths)
     pieces = ragged_array_ops.split(
         source, num_or_size_splits, axis, num, name)
     self.assertLen(pieces, len(expected))
     for actual, want in zip(pieces, expected):
         self.assertAllEqual(actual, want)
コード例 #4
0
 def split_tensors(rt, split_lengths):
     """Split `rt` by `split_lengths` via `ragged_array_ops.split`.

     `axis` and `num` are not defined in this function. Presumably they are
     captured from an enclosing test function; confirm against the caller.
     """
     split_kwargs = {'axis': axis, 'num': num}
     return ragged_array_ops.split(rt, split_lengths, **split_kwargs)
コード例 #5
0
 def split_tensors(rt):
     """Split `rt` by passing the integer 2 as the split argument."""
     num_pieces = 2
     return ragged_array_ops.split(rt, num_pieces)