def testRaggedGatherNdStaticError(self,
                                  params,
                                  indices,
                                  message=None,
                                  error=ValueError):
    """Checks that `gather_nd(params, indices)` raises the expected error.

    Args:
      params: Value forwarded as the first argument of `gather_nd`.
      indices: Value forwarded as the second argument of `gather_nd`.
      message: Regex the raised error's message is matched against.
      error: Exception type that is expected to be raised.
    """
    expect_failure = self.assertRaisesRegex(error, message)
    with expect_failure:
        ragged_gather_ops.gather_nd(params, indices)
  def testRaggedGatherNdUnknownRankError(self):
    """Checks the errors raised for placeholder indices of unknown shape.

    Nothing is checked when executing eagerly; the test returns immediately.
    """
    if context.executing_eagerly():
      return

    params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd']])
    # One placeholder with no shape at all, one with a single unknown dim.
    no_shape_indices = array_ops.placeholder(dtypes.int32, shape=None)
    unknown_dim_indices = array_ops.placeholder(dtypes.int32, shape=[None])

    # Each entry pairs an `indices` value with the regex its error must match.
    failure_cases = (
        (no_shape_indices, 'indices.rank be statically known.'),
        (unknown_dim_indices,
         r'indices.shape\[-1\] must be statically known.'),
    )
    for bad_indices, expected_regex in failure_cases:
      with self.assertRaisesRegex(ValueError, expected_regex):
        ragged_gather_ops.gather_nd(params, bad_indices)
Ejemplo n.º 3
0
def _ragged_gather_nd_v1(params, indices, name=None, batch_dims=0):
    """Forwards to `ragged_gather_ops.gather_nd`, passing every argument by keyword.

    This wrapper accepts `name` before `batch_dims`; because the call below
    uses keywords only, the positional order of the wrapped function does
    not matter here.
    """
    forwarded = dict(
        params=params,
        indices=indices,
        batch_dims=batch_dims,
        name=name,
    )
    gathered = ragged_gather_ops.gather_nd(**forwarded)
    return gathered
 def testRaggedGatherNd(self, descr, params, indices, expected):
   """Checks that `gather_nd(params, indices)` equals `expected`.

   `descr` is accepted but not used by the body.
   """
   self.assertAllEqual(
       ragged_gather_ops.gather_nd(params, indices), expected)
Ejemplo n.º 5
0
def _ragged_gather_nd_v1(params, indices, name=None, batch_dims=0):
  """Calls `ragged_gather_ops.gather_nd` with all arguments given by keyword."""
  return ragged_gather_ops.gather_nd(
      params=params, indices=indices,
      batch_dims=batch_dims, name=name)