def testBadDefaultShape(self):
  """Rejects a `default` whose shape disagrees with the gathered slices.

  Covers a vector default with a 1-D `params` and a scalar default with a
  2-D `params`. Either a static `ValueError` or a runtime
  `InvalidArgumentError` is accepted.
  """
  mismatched_cases = [
      dict(params=[0, 1, 2, 3], indices=[0], default=[0]),
      dict(params=[[0, 1], [2, 3]], indices=[0], default=0),
  ]
  for call_kwargs in mismatched_cases:
    with self.assertRaises((ValueError, errors.InvalidArgumentError)):
      pointer_ops.gather_with_default(**call_kwargs)
 def testBadAxis(self):
   """Rejects an `axis` that is not a valid dimension of `params`.

   Uses axis=1 for a 1-D `params` and axis=2 for a 2-D `params`. Either a
   static `ValueError` or a runtime `InvalidArgumentError` is accepted.
   """
   out_of_range_cases = (
       dict(params=[0, 1, 2, 3], indices=[0], default=-1, axis=1),
       dict(params=[[0, 1], [2, 3]], indices=[0], default=[0, 0], axis=2),
   )
   for call_kwargs in out_of_range_cases:
     with self.assertRaises((ValueError, errors.InvalidArgumentError)):
       pointer_ops.gather_with_default(**call_kwargs)
  def testIndexOutOfRange(self):
    """Indices outside `params` (other than the -1 default marker) fail.

    Note: because of the way gather_with_default is implemented, these error
    messages will report values and ranges that are one greater than those
    that were supplied to gather_with_default.
    """
    # `assertRaisesRegexp` is a deprecated alias that was removed from
    # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
    error_pattern = r'indices\[0\] = .* is not in .*'

    # 4 is past the end of a 4-element `params`.
    with self.assertRaisesRegex(errors.InvalidArgumentError, error_pattern):
      self.evaluate(
          pointer_ops.gather_with_default(
              params=[0, 1, 2, 3], indices=[4], default=0))

    # -2 is negative, but only -1 selects the default.
    with self.assertRaisesRegex(errors.InvalidArgumentError, error_pattern):
      self.evaluate(
          pointer_ops.gather_with_default(
              params=[0, 1, 2, 3], indices=[-2], default=0))
 def testAxisGreaterThan0_BehaviorMatchesTfGather(self):
   """With no -1 indices, gathering along axis=1 agrees with `tf.gather`."""
   params = [
       ['a1', 'a2', 'a3', 'a4'],
       ['b1', 'b2', 'b3', 'b4'],
       ['c1', 'c2', 'c3', 'c4'],
   ]
   indices = [2, 0, 2, 1]
   # Arguments are evaluated left to right: the op under test is built first,
   # then the reference `array_ops.gather`.
   self.assertAllEqual(
       pointer_ops.gather_with_default(params, indices, '__', axis=1),
       array_ops.gather(params, indices, axis=1))
 def testScalarIndicesWith2DParams(self, params, indices, default):
   """Gathers a single row using a scalar index.

   The arguments are presumably supplied by a parameterized-test decorator
   that is not visible here (TODO confirm).

   Args:
     params: Nested Python list used as `params` (2-D per the test name).
     indices: Python int index into `params`; -1 selects `default`.
     default: Value expected in the result when `indices == -1`.
   """
   # Validate the precondition before building any tensors. Use a TestCase
   # assertion rather than a bare `assert`, which is stripped under `-O`.
   self.assertIsInstance(indices, int)
   indices_t = constant_op.constant(indices, dtype=dtypes.int32)
   params_t = constant_op.constant(params)
   gathered = pointer_ops.gather_with_default(params_t, indices_t, default)
   expected = default if indices == -1 else params[indices]
   self.assertAllEqual(gathered, expected)
   # When there are no -1 indices, check that behavior matches tf.gather.
   if indices != -1:
     self.assertAllEqual(gathered, array_ops.gather(params_t, indices_t))
 def testAxisGreaterThan0(self):
   """Gathers columns (axis=1), filling each -1 column with the default."""
   string_params = [['a0', 'a1', 'a2', 'a3', 'a4'],
                    ['b0', 'b1', 'b2', 'b3', 'b4'],
                    ['c0', 'c1', 'c2', 'c3', 'c4']]  # pyformat: disable
   column_indices = [2, 0, -1, 4, -1]
   # String tensors compare against bytes values.
   expected_columns = [[b'a2', b'a0', b'__', b'a4', b'__'],
                       [b'b2', b'b0', b'__', b'b4', b'__'],
                       [b'c2', b'c0', b'__', b'c4', b'__']]  # pyformat: disable
   result = pointer_ops.gather_with_default(
       string_params, column_indices, '__', axis=1)
   self.assertAllEqual(result, expected_columns)
 def testVectorIndices(self, params, indices, default, expected_shape=None):
   """Gathers rows with a 1-D list of indices; -1 entries take `default`.

   Args:
     params: Nested Python list used as `params`.
     indices: Python list of int indices; -1 selects `default`.
     default: Value placed wherever `indices` holds -1.
     expected_shape: Optional shape passed to `constant_op.constant` when
       building the expected tensor.
   """
   indices_t = constant_op.constant(indices, dtype=dtypes.int32)
   params_t = constant_op.constant(params)
   gathered = pointer_ops.gather_with_default(params_t, indices_t, default)

   expected_rows = []
   for index in indices:
     if index != -1:
       expected_rows.append(params[index])
     else:
       expected_rows.append(default)
   self.assertAllEqual(
       gathered, constant_op.constant(expected_rows, shape=expected_shape))

   # tf.gather has no notion of a default, so compare against it only when
   # every index is a real position.
   if all(index != -1 for index in indices):
     self.assertAllEqual(gathered, array_ops.gather(params_t, indices_t))
  def testNegativeAxis(self):
    """Checks that each negative `axis` behaves like its non-negative twin.

    For 1-D, 2-D and 3-D `params`, the same `indices` are gathered once with a
    non-negative axis and once with the equivalent negative axis. Both results
    are compared with the same expected value; -1 indices take the default.
    """
    params_1d = _MakeTestTensor(shape=[3])
    params_2d = _MakeTestTensor(shape=[3, 3])
    params_3d = _MakeTestTensor(shape=[3, 3, 3])
    indices = [2, 0, -1, 1, -1]

    # Each entry: (params, default, axis, equivalent negative axis, expected).
    cases = [
        (params_1d, '__', 0, -1,
         [b'v2', b'v0', b'__', b'v1', b'__']),
        (params_2d, ['__', '__', '__'], 0, -2,
         [[b'v20', b'v21', b'v22'],
          [b'v00', b'v01', b'v02'],
          [b'__', b'__', b'__'],
          [b'v10', b'v11', b'v12'],
          [b'__', b'__', b'__']]),
        (params_2d, '__', 1, -1,
         [[b'v02', b'v00', b'__', b'v01', b'__'],
          [b'v12', b'v10', b'__', b'v11', b'__'],
          [b'v22', b'v20', b'__', b'v21', b'__']]),
        (params_3d, '__', 2, -1,
         [[[b'v002', b'v000', b'__', b'v001', b'__'],
           [b'v012', b'v010', b'__', b'v011', b'__'],
           [b'v022', b'v020', b'__', b'v021', b'__']],
          [[b'v102', b'v100', b'__', b'v101', b'__'],
           [b'v112', b'v110', b'__', b'v111', b'__'],
           [b'v122', b'v120', b'__', b'v121', b'__']],
          [[b'v202', b'v200', b'__', b'v201', b'__'],
           [b'v212', b'v210', b'__', b'v211', b'__'],
           [b'v222', b'v220', b'__', b'v221', b'__']]]),
    ]  # pyformat: disable

    # Build every op first, then run all comparisons in the same order.
    outcomes = []
    for source, default, axis, negative_axis, expected in cases:
      by_axis = pointer_ops.gather_with_default(
          source, indices, default, axis=axis)
      by_negative_axis = pointer_ops.gather_with_default(
          source, indices, default, axis=negative_axis)
      outcomes.append((by_axis, by_negative_axis, expected))

    for by_axis, by_negative_axis, expected in outcomes:
      self.assertAllEqual(by_axis, expected)
      self.assertAllEqual(by_negative_axis, expected)
Example No. 9
0
 def testDocStringExample(self):
   """Gathers letters, with each -1 index replaced by the '_' default.

   Presumably mirrors the usage example in the gather_with_default docstring
   (per the test name; not visible here).
   """
   letters = ['a', 'b', 'c', 'd']
   positions = [2, 0, -1, 2, -1]
   gathered = pointer_ops.gather_with_default(letters, positions, '_')
   self.assertAllEqual(gathered, [b'c', b'a', b'_', b'c', b'_'])
Example No. 10
0
 def testBadDefaultDtype(self):
   """A string `default` cannot be combined with integer `params`.

   The regex accepts either of two TypeError messages for the dtype mismatch.
   """
   # `assertRaisesRegexp` is a deprecated alias that was removed from
   # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
   with self.assertRaisesRegex(TypeError,
                               'Expected int32.*|Cannot convert .*'):
     pointer_ops.gather_with_default(params=[0, 1, 2, 3],
                                     indices=[0],
                                     default='a')
Example No. 11
0
    def testNegativeAxis(self):
        """Checks that each negative `axis` behaves like its non-negative twin.

        For 1-D, 2-D and 3-D `params`, the same `indices` are gathered with a
        non-negative axis and with the equivalent negative axis, and both are
        compared with one expected value (-1 indices take the default).

        Gathered string tensors evaluate to `bytes`, so expected values are
        written as bytes; comparing against `str` values is not reliable.
        """
        params_1d = _MakeTestTensor(shape=[3])
        params_2d = _MakeTestTensor(shape=[3, 3])
        params_3d = _MakeTestTensor(shape=[3, 3, 3])
        indices = [2, 0, -1, 1, -1]

        gathered1a = pointer_ops.gather_with_default(params_1d,
                                                     indices,
                                                     '__',
                                                     axis=0)
        gathered1b = pointer_ops.gather_with_default(params_1d,
                                                     indices,
                                                     '__',
                                                     axis=-1)
        expected1 = [b'v2', b'v0', b'__', b'v1', b'__']

        gathered2a = pointer_ops.gather_with_default(params_2d,
                                                     indices,
                                                     ['__', '__', '__'],
                                                     axis=0)
        gathered2b = pointer_ops.gather_with_default(params_2d,
                                                     indices,
                                                     ['__', '__', '__'],
                                                     axis=-2)
        expected2 = [[b'v20', b'v21', b'v22'], [b'v00', b'v01', b'v02'],
                     [b'__', b'__', b'__'], [b'v10', b'v11', b'v12'],
                     [b'__', b'__', b'__']]  # pyformat: disable

        gathered3a = pointer_ops.gather_with_default(params_2d,
                                                     indices,
                                                     '__',
                                                     axis=1)
        gathered3b = pointer_ops.gather_with_default(params_2d,
                                                     indices,
                                                     '__',
                                                     axis=-1)
        expected3 = [[b'v02', b'v00', b'__', b'v01', b'__'],
                     [b'v12', b'v10', b'__', b'v11', b'__'],
                     [b'v22', b'v20', b'__', b'v21', b'__']]  # pyformat: disable

        gathered4a = pointer_ops.gather_with_default(params_3d,
                                                     indices,
                                                     '__',
                                                     axis=2)
        gathered4b = pointer_ops.gather_with_default(params_3d,
                                                     indices,
                                                     '__',
                                                     axis=-1)
        # bytes %-formatting (Python 3.5+) builds b'v<i><j><k>' directly.
        expected4 = [[[
            b'v%d%d2' % (i, j),
            b'v%d%d0' % (i, j), b'__',
            b'v%d%d1' % (i, j), b'__'
        ] for j in range(3)] for i in range(3)]  # pyformat: disable

        self.assertAllEqual(gathered1a, expected1)
        self.assertAllEqual(gathered1b, expected1)
        self.assertAllEqual(gathered2a, expected2)
        self.assertAllEqual(gathered2b, expected2)
        self.assertAllEqual(gathered3a, expected3)
        self.assertAllEqual(gathered3b, expected3)
        self.assertAllEqual(gathered4a, expected4)
        self.assertAllEqual(gathered4b, expected4)