def testRaggedBatchGatherStaticError(self,
                                     params,
                                     indices,
                                     message=None,
                                     error=ValueError):
  """Asserts that `ragged.batch_gather(params, indices)` raises `error`.

  Args:
    params: The `params` argument passed to `ragged.batch_gather`.
    indices: The `indices` argument passed to `ragged.batch_gather`.
    message: Regular expression the raised exception's message must match.
      Now optional (defaults to None), consistent with the other
      `testRaggedBatchGatherStaticError` signature in this file. When it is
      None, unittest's assertRaises context only verifies the exception type.
    error: Exception type expected to be raised. Defaults to `ValueError`.
  """
  # NOTE(review): a method with this exact name is defined again in this
  # file; if both live in the same class, the later definition replaces this
  # one and this body is never run.
  with self.assertRaisesRegexp(error, message):
    ragged.batch_gather(params, indices)
def testRaggedBatchGatherStaticError(self, params, indices, message=None,
                                     error=ValueError):
  """Checks that `ragged.batch_gather` fails on the given arguments.

  Args:
    params: Forwarded as the `params` argument of `ragged.batch_gather`.
    indices: Forwarded as the `indices` argument of `ragged.batch_gather`.
    message: Optional regex the exception message must match (None means
      only the exception type is checked).
    error: The exception class that must be raised (default `ValueError`).
  """
  expectation = self.assertRaisesRegexp(error, message)
  with expectation:
    ragged.batch_gather(params, indices)
def testRaggedBatchGatherUnknownRankError(self):
  """Checks that `ragged.batch_gather` rejects indices of unknown shape.

  Both a dense placeholder created with `shape=None` and a ragged tensor
  built from that placeholder must raise `ValueError` with the same message.
  """
  params = [['a', 'b'], ['c', 'd']]
  dense_indices = array_ops.placeholder(dtypes.int32, shape=None)
  ragged_indices = ragged.from_row_splits(dense_indices, [0, 2, 4])
  expected_msg = 'batch_gather does not allow indices with unknown shape.'
  # Dense variant first, then the ragged variant (same order as before).
  for bad_indices in (dense_indices, ragged_indices):
    with self.assertRaisesRegexp(ValueError, expected_msg):
      ragged.batch_gather(params, bad_indices)
def testRaggedBatchGather(self, descr, params, indices, expected):
  """Compares `ragged.batch_gather(params, indices)` against `expected`.

  First the `ragged_rank` attributes are compared (objects without that
  attribute count as rank 0). Then the result is evaluated inside a test
  session and its `tolist()` form is compared with `expected`, which is
  converted via `tolist()` when it provides one. `descr` is not read here
  (presumably a label for the parameterized case — confirm at decorator).
  """
  actual = ragged.batch_gather(params, indices)
  actual_rank = getattr(actual, 'ragged_rank', 0)
  expected_rank = getattr(expected, 'ragged_rank', 0)
  self.assertEqual(actual_rank, expected_rank)
  with self.test_session():
    expected_list = (
        expected.tolist() if hasattr(expected, 'tolist') else expected)
    self.assertEqual(actual.eval().tolist(), expected_list)
def testRaggedBatchGather(self, descr, params, indices, expected):
  """Verifies both the ragged rank and the values of a batch_gather result.

  Args:
    descr: Unused in this body (presumably a test-case description).
    params: Forwarded to `ragged.batch_gather`.
    indices: Forwarded to `ragged.batch_gather`.
    expected: Expected result; its `ragged_rank` (default 0) is compared
      first, and its `tolist()` form (when available) is compared with the
      evaluated result.
  """
  gathered = ragged.batch_gather(params, indices)
  self.assertEqual(
      getattr(gathered, 'ragged_rank', 0),
      getattr(expected, 'ragged_rank', 0))
  with self.test_session():
    want = expected
    if hasattr(want, 'tolist'):
      want = want.tolist()
    got = gathered.eval().tolist()
    self.assertEqual(got, want)
def testRaggedBatchGather(self, descr, params, indices, expected):
  """Asserts `ragged.batch_gather(params, indices)` equals `expected`.

  Equality is delegated to `self.assertRaggedEqual`. `descr` is not used
  in the body (presumably a parameterized-case label — confirm).
  """
  self.assertRaggedEqual(ragged.batch_gather(params, indices), expected)