def testBooleanScatterUpdate(self):
  """batch_scatter_update must write into a boolean variable.

  A two-element boolean variable receives two scatter updates. The first
  uses plain Python-list indices and the second uses an explicit int64
  index tensor. Together they flip both elements.
  """
  flags = variables.Variable([True, False])

  # Position 1: False -> True, indices given as a Python list.
  raise_second = state_ops.batch_scatter_update(flags, [1], [True])

  # Position 0: True -> False, indices given as an int64 tensor.
  first_index = constant_op.constant([0], dtype=dtypes.int64)
  clear_first = state_ops.batch_scatter_update(flags, first_index, [False])

  self.evaluate(variables.variables_initializer([flags]))
  self.evaluate([raise_second, clear_first])

  observed = self.evaluate(flags)
  self.assertAllEqual([False, True], observed)
def testBooleanScatterUpdate(self):
  """batch_scatter_update must write into a boolean variable (graph mode).

  Runs inside a CPU-only test session. One update uses list indices and
  the other uses an int64 index tensor; both are run in a single
  session call.
  """
  with self.test_session(use_gpu=False) as sess:
    flags = variables.Variable([True, False])

    # Set element 1 to True via a Python-list index.
    raise_second = state_ops.batch_scatter_update(flags, [1], [True])

    # Set element 0 to False via an explicit int64 index tensor.
    first_index = constant_op.constant([0], dtype=dtypes.int64)
    clear_first = state_ops.batch_scatter_update(flags, first_index,
                                                 [False])

    flags.initializer.run()
    sess.run([raise_second, clear_first])

    observed = flags.eval()
    self.assertAllEqual([False, True], observed)
def testBooleanScatterUpdate(self):
  """Boolean scatter updates flip both elements of a bool variable.

  Executed in a CPU-only test session: list-typed and int64-tensor-typed
  indices are each used for one of the two updates.
  """
  with self.test_session(use_gpu=False) as active_session:
    bool_var = variables.Variable([True, False])

    int64_zero = constant_op.constant([0], dtype=dtypes.int64)
    to_true = state_ops.batch_scatter_update(bool_var, [1], [True])
    to_false = state_ops.batch_scatter_update(bool_var, int64_zero,
                                              [False])

    bool_var.initializer.run()
    active_session.run([to_true, to_false])

    self.assertAllEqual([False, True], bool_var.eval())
def testScatterOutOfRange(self):
  """batch_scatter_update must reject indices outside [0, 6).

  The in-range update must succeed. A negative index and an index equal
  to the dimension size must each raise an op error that names the
  offending position.
  """
  params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
  updates = np.array([-3, -4, -5]).astype(np.float32)
  target = variables.Variable(params)
  self.evaluate(variables.variables_initializer([target]))

  def apply_scatter(index_array):
    return self.evaluate(
        state_ops.batch_scatter_update(target, index_array, updates))

  # All indices are valid, so this must not raise.
  apply_scatter(np.array([2, 0, 5]))

  # Each pair is (bad indices, regex the op error must match).
  failing_cases = (
      (np.array([-1, 0, 5]),
       r'indices\[0\] = \[-1\] does not index into shape \[6\]'),
      (np.array([2, 0, 6]),
       r'indices\[2\] = \[6\] does not index into shape \[6\]'),
  )
  for bad_indices, expected_message in failing_cases:
    with self.assertRaisesOpError(expected_message):
      apply_scatter(bad_indices)
def testScatterOutOfRange(self):
  """batch_scatter_update must reject indices outside [0, 6) (graph mode).

  Runs in a CPU-only test session. A valid update is evaluated first;
  then a negative index and an index equal to the size must each raise
  an op error.
  """
  params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
  updates = np.array([-3, -4, -5]).astype(np.float32)
  with self.test_session(use_gpu=False):
    target = variables.Variable(params)
    target.initializer.run()

    def run_scatter(index_array):
      return state_ops.batch_scatter_update(target, index_array,
                                            updates).eval()

    # Valid indices: must complete without error.
    run_scatter(np.array([2, 0, 5]))

    # Each pair is (bad indices, regex the op error must match).
    failing_cases = (
        (np.array([-1, 0, 5]),
         r'indices\[0\] = \[-1\] does not index into shape \[6\]'),
        (np.array([2, 0, 6]),
         r'indices\[2\] = \[6\] does not index into shape \[6\]'),
    )
    for bad_indices, expected_message in failing_cases:
      with self.assertRaisesOpError(expected_message):
        run_scatter(bad_indices)