コード例 #1
0
 def testRaggedBatchGatherStaticError(self,
                                      params,
                                      indices,
                                      message,
                                      error=ValueError):
   """Checks that ragged.batch_gather(params, indices) raises `error`.

   Args:
     params: The `params` argument forwarded to `ragged.batch_gather`.
     indices: The `indices` argument forwarded to `ragged.batch_gather`.
     message: Regular expression that must be found (via `re.search`) in the
       string form of the raised exception.
     error: Exception type the call is expected to raise.
   """
   # assertRaisesRegexp is a deprecated alias of assertRaisesRegex that was
   # removed from unittest in Python 3.12.
   with self.assertRaisesRegex(error, message):
     ragged.batch_gather(params, indices)
コード例 #2
0
 def testRaggedBatchGatherStaticError(self,
                                      params,
                                      indices,
                                      message=None,
                                      error=ValueError):
   """Checks that ragged.batch_gather(params, indices) raises `error`.

   Args:
     params: The `params` argument forwarded to `ragged.batch_gather`.
     indices: The `indices` argument forwarded to `ragged.batch_gather`.
     message: Optional regular expression that must be found (via
       `re.search`) in the string form of the raised exception. In CPython's
       unittest, passing None skips the message check, so only the exception
       type is verified.
     error: Exception type the call is expected to raise.
   """
   # assertRaisesRegexp is a deprecated alias of assertRaisesRegex that was
   # removed from unittest in Python 3.12.
   with self.assertRaisesRegex(error, message):
     ragged.batch_gather(params, indices)
コード例 #3
0
  def testRaggedBatchGatherUnknownRankError(self):
    """Checks that batch_gather raises ValueError for unknown-shape indices.

    Both a placeholder created with `shape=None` and a ragged tensor built
    from that placeholder via `ragged.from_row_splits` are expected to be
    rejected with the same error message.
    """
    params = [['a', 'b'], ['c', 'd']]
    # shape=None leaves the placeholder's shape (and hence rank) unknown.
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    ragged_indices = ragged.from_row_splits(indices, [0, 2, 4])
    # Treated as a regular expression by assertRaisesRegex (re.search).
    err_msg = 'batch_gather does not allow indices with unknown shape.'

    # assertRaisesRegexp is a deprecated alias of assertRaisesRegex that was
    # removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, err_msg):
      ragged.batch_gather(params, indices)

    with self.assertRaisesRegex(ValueError, err_msg):
      ragged.batch_gather(params, ragged_indices)
コード例 #4
0
  def testRaggedBatchGatherUnknownRankError(self):
    """Checks that batch_gather raises ValueError for unknown-shape indices.

    Both a placeholder created with `shape=None` and a ragged tensor built
    from that placeholder via `ragged.from_row_splits` are expected to be
    rejected with the same error message.
    """
    params = [['a', 'b'], ['c', 'd']]
    # shape=None leaves the placeholder's shape (and hence rank) unknown.
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    ragged_indices = ragged.from_row_splits(indices, [0, 2, 4])
    # Treated as a regular expression by assertRaisesRegex (re.search).
    err_msg = 'batch_gather does not allow indices with unknown shape.'

    # assertRaisesRegexp is a deprecated alias of assertRaisesRegex that was
    # removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, err_msg):
      ragged.batch_gather(params, indices)

    with self.assertRaisesRegex(ValueError, err_msg):
      ragged.batch_gather(params, ragged_indices)
コード例 #5
0
 def testRaggedBatchGather(self, descr, params, indices, expected):
   """Compares ragged.batch_gather(params, indices) against `expected`.

   The ragged ranks are compared first, with a missing `ragged_rank`
   attribute counted as 0. Then, inside a test session, the evaluated result
   is compared with `expected` as nested Python lists.

   Args:
     descr: Not read by the body; presumably a case label supplied by the
       test parameterization.
     params: Forwarded to `ragged.batch_gather`.
     indices: Forwarded to `ragged.batch_gather`.
     expected: Expected result; converted with `.tolist()` when it has one.
   """
   actual = ragged.batch_gather(params, indices)
   actual_rank = getattr(actual, 'ragged_rank', 0)
   expected_rank = getattr(expected, 'ragged_rank', 0)
   self.assertEqual(actual_rank, expected_rank)

   with self.test_session():
     # Normalize array-like expectations to plain lists before comparing.
     expected_list = (
         expected.tolist() if hasattr(expected, 'tolist') else expected)
     self.assertEqual(actual.eval().tolist(), expected_list)
コード例 #6
0
 def testRaggedBatchGather(self, descr, params, indices, expected):
   """Verifies the output of ragged.batch_gather against `expected`.

   Args:
     descr: Not read by the body; presumably a case label supplied by the
       test parameterization.
     params: Forwarded to `ragged.batch_gather`.
     indices: Forwarded to `ragged.batch_gather`.
     expected: Expected value. Its `ragged_rank` (0 if absent) must match the
       result's. Its `.tolist()` form, when available, is compared with the
       evaluated result's `.tolist()`.
   """
   gathered = ragged.batch_gather(params, indices)
   self.assertEqual(getattr(gathered, 'ragged_rank', 0),
                    getattr(expected, 'ragged_rank', 0))
   with self.test_session():
     if not hasattr(expected, 'tolist'):
       want = expected
     else:
       want = expected.tolist()
     got = gathered.eval().tolist()
     self.assertEqual(got, want)
コード例 #7
0
 def testRaggedBatchGather(self, descr, params, indices, expected):
   """Asserts that ragged.batch_gather(params, indices) equals `expected`.

   The comparison itself is delegated to `self.assertRaggedEqual`, which is
   defined outside this view.

   Args:
     descr: Not read by the body; presumably a case label supplied by the
       test parameterization.
     params: Forwarded to `ragged.batch_gather`.
     indices: Forwarded to `ragged.batch_gather`.
     expected: Value the gathered result is compared against.
   """
   actual = ragged.batch_gather(params, indices)
   self.assertRaggedEqual(actual, expected)