예제 #1
0
 def testTensorParamsAndTensorIndices(self):
   """Gathers dense params with dense indices; the result is an ops.Tensor."""
   params = ['a', 'b', 'c', 'd', 'e']
   indices = [2, 0, 2, 1]
   with self.cached_session():
     self.assertEqual(
         ragged.gather(params, indices).eval().tolist(),
         [b'c', b'a', b'c', b'b'])
     # With no ragged inputs, gather is expected to return a dense tensor.
     self.assertIsInstance(ragged.gather(params, indices), ops.Tensor)
 def testTensorParamsAndTensorIndices(self):
     """Gathers dense params with dense indices; the result is an ops.Tensor."""
     params = ['a', 'b', 'c', 'd', 'e']
     indices = [2, 0, 2, 1]
     with self.cached_session():
         self.assertEqual(
             ragged.gather(params, indices).eval().tolist(),
             [b'c', b'a', b'c', b'b'])
         # With no ragged inputs, gather is expected to return a dense tensor.
         self.assertIsInstance(ragged.gather(params, indices), ops.Tensor)
예제 #3
0
 def testDocStringExamples(self):
     """Checks the docstring example cases for ragged.gather."""
     dense_params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
     dense_indices = constant_op.constant([3, 1, 2, 1, 0])
     nested_params = ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
     nested_indices = ragged.constant([[3, 1, 2], [1], [], [0]])
     # Each case is (params, indices, expected gathered value). Cases run in
     # order, and the first mismatch fails the test.
     cases = [
         (dense_params, nested_indices,
          [[b'd', b'b', b'c'], [b'b'], [], [b'a']]),
         (nested_params, dense_indices,
          [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]),
         (nested_params, nested_indices,
          [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]),
     ]
     for params, indices, expected in cases:
         self.assertRaggedEqual(ragged.gather(params, indices), expected)
예제 #4
0
 def testOutOfBoundsError(self):
     """Checks that out-of-range indices raise InvalidArgumentError."""
     tensor_params = ['a', 'b', 'c']
     tensor_indices = [0, 1, 2]
     ragged_params = ragged.constant([['a', 'b'], ['c']])
     ragged_indices = ragged.constant([[0, 3]])
     # assertRaisesRegex is the supported name. The assertRaisesRegexp alias
     # is deprecated and was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 r'indices\[1\] = 3 is not in \[0, 3\)'):
         self.evaluate(ragged.gather(tensor_params, ragged_indices))
     # ragged_params has only 2 rows, so index 2 is out of range.
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 r'indices\[2\] = 2 is not in \[0, 2\)'):
         self.evaluate(ragged.gather(ragged_params, tensor_indices))
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 r'indices\[1\] = 3 is not in \[0, 2\)'):
         self.evaluate(ragged.gather(ragged_params, ragged_indices))
 def testDocStringExamples(self):
     """Checks the docstring example cases for ragged.gather."""
     params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
     indices = constant_op.constant([3, 1, 2, 1, 0])
     ragged_params = ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
     ragged_indices = ragged.constant([[3, 1, 2], [1], [], [0]])
     with self.cached_session():
         # Dense params, ragged indices.
         self.assertEqual(
             ragged.gather(params, ragged_indices).eval().tolist(),
             [[b'd', b'b', b'c'], [b'b'], [], [b'a']])
         # Ragged params, dense indices: whole rows are selected.
         self.assertEqual(
             ragged.gather(ragged_params, indices).eval().tolist(),
             [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
         # Ragged params, ragged indices: each output row is a list of rows.
         self.assertEqual(
             ragged.gather(ragged_params, ragged_indices).eval().tolist(),
             [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
예제 #6
0
 def testDocStringExamples(self):
   """Checks the docstring example cases for ragged.gather."""
   params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
   indices = constant_op.constant([3, 1, 2, 1, 0])
   ragged_params = ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
   ragged_indices = ragged.constant([[3, 1, 2], [1], [], [0]])
   with self.cached_session():
     # Dense params, ragged indices.
     self.assertEqual(
         ragged.gather(params, ragged_indices).eval().tolist(),
         [[b'd', b'b', b'c'], [b'b'], [], [b'a']])
     # Ragged params, dense indices: whole rows are selected.
     self.assertEqual(
         ragged.gather(ragged_params, indices).eval().tolist(),
         [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
     # Ragged params, ragged indices: each output row is a list of rows.
     self.assertEqual(
         ragged.gather(ragged_params, ragged_indices).eval().tolist(),
         [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
예제 #7
0
 def testOutOfBoundsError(self):
   """Checks that out-of-range indices raise InvalidArgumentError on eval."""
   tensor_params = ['a', 'b', 'c']
   tensor_indices = [0, 1, 2]
   ragged_params = ragged.constant([['a', 'b'], ['c']])
   ragged_indices = ragged.constant([[0, 3]])
   with self.cached_session():
     # The gather ops are built before each assertion and only the bound
     # `.eval` is passed in, so the error is expected when the op runs.
     self.assertRaisesRegex(errors.InvalidArgumentError,
                            r'indices\[1\] = 3 is not in \[0, 3\)',
                            ragged.gather(tensor_params, ragged_indices).eval)
     # ragged_params has only 2 rows, so index 2 is out of range.
     self.assertRaisesRegex(errors.InvalidArgumentError,
                            r'indices\[2\] = 2 is not in \[0, 2\)',
                            ragged.gather(ragged_params, tensor_indices).eval)
     self.assertRaisesRegex(errors.InvalidArgumentError,
                            r'indices\[1\] = 3 is not in \[0, 2\)',
                            ragged.gather(ragged_params, ragged_indices).eval)
예제 #8
0
 def testRaggedParamsAndTensorIndices(self):
   """Gathers whole rows of ragged params using dense indices."""
   params = ragged.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']])
   indices = [2, 0, 2, 1]
   with self.cached_session():
     self.assertEqual(
         ragged.gather(params, indices).eval().tolist(),
         [[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']])
예제 #9
0
 def testTensorParamsAndRaggedIndices(self):
   """Gathers dense params with ragged indices; rows follow the indices."""
   params = ['a', 'b', 'c', 'd', 'e']
   indices = ragged.constant([[2, 1], [1, 2, 0], [3]])
   with self.cached_session():
     self.assertEqual(
         ragged.gather(params, indices).eval().tolist(),
         [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])
예제 #10
0
 def testRaggedParamsAndTensorIndices(self):
     """Gathers whole rows of ragged params using dense indices."""
     rows = ragged.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']])
     row_ids = [2, 0, 2, 1]
     expected_rows = [[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']]
     self.assertRaggedEqual(ragged.gather(rows, row_ids), expected_rows)
 def testTensorParamsAndRaggedIndices(self):
     """Gathers dense params with ragged indices; rows follow the indices."""
     params = ['a', 'b', 'c', 'd', 'e']
     indices = ragged.constant([[2, 1], [1, 2, 0], [3]])
     with self.cached_session():
         self.assertEqual(
             ragged.gather(params, indices).eval().tolist(),
             [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])
 def testRaggedParamsAndScalarIndices(self):
     """Gathers with a scalar index, which selects a single ragged row."""
     params = ragged.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [],
                               ['g']])
     indices = 1
     with self.cached_session():
         self.assertEqual(
             ragged.gather(params, indices).eval().tolist(),
             [b'c', b'd', b'e'])
예제 #13
0
 def testRaggedParamsAndRaggedIndices(self):
   """Gathers ragged params with ragged indices, yielding rows of rows."""
   params = ragged.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']])
   indices = ragged.constant([[2, 1], [1, 2, 0], [3]])
   with self.cached_session():
     self.assertEqual(
         ragged.gather(params, indices).eval().tolist(),
         [[[b'f'], [b'c', b'd', b'e']],                # [[p[2], p[1]      ],
          [[b'c', b'd', b'e'], [b'f'], [b'a', b'b']],  #  [p[1], p[2], p[0]],
          [[]]]                                        #  [p[3]            ]]
     )  # pyformat: disable
예제 #14
0
 def test3DRaggedParamsAnd2DTensorIndices(self):
   """Gathers rows of 3-D ragged params using a 2-D dense index matrix."""
   params = ragged.constant([[['a', 'b'], []], [['c', 'd'], ['e'], ['f']],
                             [['g']]])
   indices = [[1, 2], [0, 1], [2, 2]]
   with self.cached_session():
     self.assertEqual(
         ragged.gather(params, indices).eval().tolist(),
         [[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]],            # [[p1, p2],
          [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]],  #  [p0, p1],
          [[[b'g']], [[b'g']]]]                                  #  [p2, p2]]
     )  # pyformat: disable
예제 #15
0
 def testTensorParamsAnd4DRaggedIndices(self):
     """Gathers dense params with ragged indices of ragged_rank 2, inner dim 2."""
     params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
     index_values = [
         [[[3, 4], [0, 6]], []],
         [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
         [[[1, 0]]],
     ]  # pyformat: disable
     indices = ragged.constant(index_values, ragged_rank=2, inner_shape=(2,))
     expected = [
         [[[b'd', b'e'], [b'a', b'g']], []],
         [[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
         [[[b'b', b'a']]],
     ]  # pyformat: disable
     self.assertRaggedEqual(ragged.gather(params, indices), expected)
예제 #16
0
 def testRaggedParamsAndRaggedIndices(self):
     """Gathers ragged params with ragged indices, yielding rows of rows."""
     params = ragged.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [],
                               ['g']])
     indices = ragged.constant([[2, 1], [1, 2, 0], [3]])
     gathered = ragged.gather(params, indices)
     self.assertRaggedEqual(
         gathered,
         [[[b'f'], [b'c', b'd', b'e']],                # [[p[2], p[1]      ],
          [[b'c', b'd', b'e'], [b'f'], [b'a', b'b']],  #  [p[1], p[2], p[0]],
          [[]]]                                        #  [p[3]            ]]
     )  # pyformat: disable
예제 #17
0
 def test3DRaggedParamsAnd2DTensorIndices(self):
     """Gathers rows of 3-D ragged params using a 2-D dense index matrix."""
     params = ragged.constant([[['a', 'b'], []], [['c', 'd'], ['e'], ['f']],
                               [['g']]])
     indices = [[1, 2], [0, 1], [2, 2]]
     gathered = ragged.gather(params, indices)
     self.assertRaggedEqual(
         gathered,
         [[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]],            # [[p1, p2],
          [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]],  #  [p0, p1],
          [[[b'g']], [[b'g']]]]                                  #  [p2, p2]]
     )  # pyformat: disable
예제 #18
0
 def testTensorParamsAnd4DRaggedIndices(self):
   """Gathers dense params with ragged indices of ragged_rank 2, inner dim 2."""
   indices = ragged.constant(
       [[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
        [[[1, 0]]]],  # pyformat: disable
       ragged_rank=2,
       inner_shape=(2,))
   params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
   with self.cached_session():
     self.assertEqual(
         ragged.gather(params, indices).eval().tolist(),
         [[[[b'd', b'e'], [b'a', b'g']], []],
          [[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
          [[[b'b', b'a']]]])  # pyformat: disable
예제 #19
0
 def testTensorParamsAndTensorIndices(self):
     """Gathers dense params with dense indices; the result is an ops.Tensor."""
     letters = ['a', 'b', 'c', 'd', 'e']
     positions = [2, 0, 2, 1]
     self.assertRaggedEqual(
         ragged.gather(letters, positions), [b'c', b'a', b'c', b'b'])
     result = ragged.gather(letters, positions)
     self.assertIsInstance(result, ops.Tensor)
예제 #20
0
 def testTensorParamsAndRaggedIndices(self):
     """Gathers dense params with ragged indices; rows follow the indices."""
     letters = ['a', 'b', 'c', 'd', 'e']
     index_rows = ragged.constant([[2, 1], [1, 2, 0], [3]])
     expected = [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']]
     self.assertRaggedEqual(ragged.gather(letters, index_rows), expected)
예제 #21
0
 def testRaggedParamsAndScalarIndices(self):
     """Gathers with a scalar index, which selects a single ragged row."""
     rows = ragged.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']])
     selected = ragged.gather(rows, 1)
     self.assertRaggedEqual(selected, [b'c', b'd', b'e'])