예제 #1
0
    def testBooleanScatterUpdate(self):
        """Scatter boolean values into a variable with batch_scatter_update.

        Index 1 is written with True (Python index list) and index 0 with
        False (int64 index tensor); the variable must read [False, True].
        """
        flags = variables.Variable([True, False])
        int64_index = constant_op.constant([0], dtype=dtypes.int64)
        set_second = state_ops.batch_scatter_update(flags, [1], [True])
        clear_first = state_ops.batch_scatter_update(
            flags, int64_index, [False])
        self.evaluate(variables.variables_initializer([flags]))

        # The two updates write disjoint indices, so running them together
        # gives the same result regardless of order.
        self.evaluate([set_second, clear_first])

        self.assertAllEqual([False, True], self.evaluate(flags))
예제 #2
0
    def testBooleanScatterUpdate(self):
        """Check batch_scatter_update on a boolean variable.

        Index 1 is set to True (Python index list) and index 0 to False
        (int64 index tensor); the variable must end up as [False, True].

        Uses ``self.evaluate`` rather than the deprecated ``test_session`` /
        ``initializer.run()`` / ``eval()`` so it works in graph and eager mode.
        """
        # NOTE(review): the former test_session(use_gpu=False) pinned graph
        # ops to CPU; placement is now left to TensorFlow -- confirm a bool
        # kernel (or soft placement) is available on GPU builds.
        var = variables.Variable([True, False])
        update0 = state_ops.batch_scatter_update(var, [1], [True])
        update1 = state_ops.batch_scatter_update(
            var, constant_op.constant([0], dtype=dtypes.int64), [False])
        # Required in graph mode before the updates can run.
        self.evaluate(variables.variables_initializer([var]))

        # The updates touch disjoint indices, so their order does not matter.
        self.evaluate([update0, update1])

        self.assertAllEqual([False, True], self.evaluate(var))
  def testBooleanScatterUpdate(self):
    """Check batch_scatter_update on a boolean variable.

    Index 1 is set to True (Python index list) and index 0 to False (int64
    index tensor); the variable must end up as [False, True].

    Uses ``self.evaluate`` rather than the deprecated ``test_session`` /
    ``initializer.run()`` / ``eval()`` so it works in graph and eager mode.
    """
    # NOTE(review): the former test_session(use_gpu=False) pinned graph ops
    # to CPU; placement is now left to TensorFlow -- confirm a bool kernel
    # (or soft placement) is available on GPU builds.
    var = variables.Variable([True, False])
    update0 = state_ops.batch_scatter_update(var, [1], [True])
    update1 = state_ops.batch_scatter_update(
        var, constant_op.constant(
            [0], dtype=dtypes.int64), [False])
    # Required in graph mode before the updates can run.
    self.evaluate(variables.variables_initializer([var]))

    # The updates touch disjoint indices, so their order does not matter.
    self.evaluate([update0, update1])

    self.assertAllEqual([False, True], self.evaluate(var))
  def testScatterOutOfRange(self):
    """Check index validation of batch_scatter_update on a rank-1 variable.

    An in-range update must succeed and write the expected values; a
    negative or too-large index must raise an op error naming that index.
    """
    params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
    updates = np.array([-3, -4, -5]).astype(np.float32)

    ref = variables.Variable(params)
    self.evaluate(variables.variables_initializer([ref]))

    # Indices all in range, no problem.
    indices = np.array([2, 0, 5])
    self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
    # ref[2] = -3, ref[0] = -4, ref[5] = -5; other positions are unchanged.
    self.assertAllEqual([-4, 2, -3, 4, 5, -5], self.evaluate(ref))

    # Test some out of range errors.
    indices = np.array([-1, 0, 5])
    with self.assertRaisesOpError(
        r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
      self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))

    indices = np.array([2, 0, 6])
    with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
                                  r'shape \[6\]'):
      self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
  def testScatterOutOfRange(self):
    """Check index validation of batch_scatter_update on a rank-1 variable.

    An in-range update must succeed and write the expected values; a
    negative or too-large index must raise an op error naming that index.

    Uses ``self.evaluate`` rather than the deprecated ``test_session`` /
    ``initializer.run()`` / ``eval()`` so it works in graph and eager mode.
    """
    # NOTE(review): the former test_session(use_gpu=False) pinned graph ops
    # to CPU; placement is now left to TensorFlow -- confirm on GPU builds.
    params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
    updates = np.array([-3, -4, -5]).astype(np.float32)

    ref = variables.Variable(params)
    self.evaluate(variables.variables_initializer([ref]))

    # Indices all in range, no problem.
    indices = np.array([2, 0, 5])
    self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
    # ref[2] = -3, ref[0] = -4, ref[5] = -5; other positions are unchanged.
    self.assertAllEqual([-4, 2, -3, 4, 5, -5], self.evaluate(ref))

    # Test some out of range errors.
    indices = np.array([-1, 0, 5])
    with self.assertRaisesOpError(
        r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
      self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))

    indices = np.array([2, 0, 6])
    with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
                                  r'shape \[6\]'):
      self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))