def testRaggedBatchGatherStaticError(self, params, indices, message=None,
                                     error=ValueError):
  """Checks that `batch_gather(params, indices)` raises the expected error.

  Presumably driven by a parameterized decorator outside this view.

  Args:
    params: Value passed as `params` to `batch_gather`.
    indices: Value passed as `indices` to `batch_gather`.
    message: Regex the raised exception's message must match. When None,
      `assertRaisesRegex` checks only the exception type.
    error: Exception type `batch_gather` is expected to raise.
  """
  # `assertRaisesRegexp` is a deprecated alias that was removed in Python 3.12;
  # use the canonical `assertRaisesRegex`.
  with self.assertRaisesRegex(error, message):
    ragged_batch_gather_ops.batch_gather(params, indices)
def testRaggedBatchGatherUnknownRankError(self):
  """Checks that `batch_gather` rejects indices whose rank is unknown.

  Both a dense placeholder with `shape=None` and a RaggedTensor built on top
  of it are expected to raise ValueError.

  NOTE(review): another method with this exact name appears later in this
  file and expects a different error message. If both live in the same
  class, the later definition replaces this one and this body never runs.
  Confirm which expectation matches the current implementation and delete
  the other.
  """
  if context.executing_eagerly():
    # Skipped under eager execution; presumably because placeholders are a
    # graph-mode-only construct -- confirm.
    return
  params = [['a', 'b'], ['c', 'd']]
  # shape=None leaves the placeholder's rank unknown.
  indices = array_ops.placeholder(dtypes.int32, shape=None)
  ragged_indices = ragged_tensor.RaggedTensor.from_row_splits(
      indices, [0, 2, 4])
  expected_message = 'batch_gather does not allow indices with unknown shape.'
  # `assertRaisesRegexp` is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, expected_message):
    ragged_batch_gather_ops.batch_gather(params, indices)
  with self.assertRaisesRegex(ValueError, expected_message):
    ragged_batch_gather_ops.batch_gather(params, ragged_indices)
def testRaggedBatchGatherUnknownRankError(self):
  """Verifies dense and ragged indices of unknown rank are both rejected.

  Does nothing under eager execution. In graph mode, each call to
  `batch_gather` must raise ValueError whose message says a negative
  `batch_dims` requires a statically known indices rank.
  """
  if context.executing_eagerly():
    return
  params = [['a', 'b'], ['c', 'd']]
  # shape=None gives a placeholder with no static rank.
  unknown_rank_indices = array_ops.placeholder(dtypes.int32, shape=None)
  unknown_rank_ragged = ragged_tensor.RaggedTensor.from_row_splits(
      unknown_rank_indices, [0, 2, 4])
  expected_regex = (r'batch_dims=-1 may only be negative '
                    r'if rank\(indices\) is statically known.')
  # Dense indices are checked before the ragged wrapper.
  for bad_indices in (unknown_rank_indices, unknown_rank_ragged):
    with self.assertRaisesRegex(ValueError, expected_regex):
      ragged_batch_gather_ops.batch_gather(params, bad_indices)
def testRaggedBatchGather(self, descr, params, indices, expected):
  """Asserts `batch_gather(params, indices)` equals `expected`.

  The comparison uses `assertAllEqual`. `descr` is not read by the body;
  presumably it only labels the parameterized case.
  """
  del descr  # Unused by the body.
  self.assertAllEqual(
      ragged_batch_gather_ops.batch_gather(params, indices), expected)
def testRaggedBatchGather(self, descr, params, indices, expected):
  """Asserts `batch_gather(params, indices)` equals `expected`.

  The comparison uses `assertRaggedEqual`. `descr` is not read by the body;
  presumably it only labels the parameterized case.
  """
  del descr  # Unused by the body.
  self.assertRaggedEqual(
      ragged_batch_gather_ops.batch_gather(params, indices), expected)