def testResVarInvalidOutputShape(self):
    """Check that updating a scalar (shape=[]) variable raises an op error.

    The variable is created from a callable initial value producing a
    float32 zeros tensor with an empty shape. Evaluating scatter_nd_update
    on it is expected to fail with "Output must be at least 1-D".
    """

    def _scalar_zero():
        # Callable initial value: float32 zeros with an empty shape.
        return array_ops.zeros(shape=[], dtype=dtypes.float32)

    scalar_var = variables.Variable(
        initial_value=_scalar_zero, dtype=dtypes.float32)
    with self.cached_session():
        scalar_var.initializer.run()
        with self.assertRaisesOpError("Output must be at least 1-D"):
            update_op = state_ops.scatter_nd_update(scalar_var, [[0]], [0.22])
            update_op.eval()
def testRank3InvalidShape2(self):
    """Check that incompatible updates are rejected when the op is built.

    Uses indices of shape [2, 2, 1] and updates of shape [2, 2] against a
    [2, 2, 2] int32 variable, and expects a ValueError whose message matches
    the "inner dimensions" pattern.
    """
    idx = array_ops.zeros([2, 2, 1], dtypes.int32)
    new_values = array_ops.zeros([2, 2], dtypes.int32)
    target_shape = np.array([2, 2, 2])
    target = variables.Variable(array_ops.zeros(target_shape, dtypes.int32))

    expected_message = "The inner \\d+ dimensions of input\\.shape="
    with self.assertRaisesWithPredicateMatch(ValueError, expected_message):
        state_ops.scatter_nd_update(target, idx, new_values)
def testRank3ValidShape(self):
    """Check the static output shape for a rank-3 update.

    With [2, 2, 2] indices and updates applied to a [2, 2, 2] int32
    variable, the op's static shape must equal the variable's shape.
    """
    idx = array_ops.zeros([2, 2, 2], dtypes.int32)
    new_values = array_ops.zeros([2, 2, 2], dtypes.int32)
    target_shape = np.array([2, 2, 2])
    target = variables.Variable(array_ops.zeros(target_shape, dtypes.int32))

    update_op = state_ops.scatter_nd_update(target, idx, new_values)
    static_shape = update_op.get_shape().as_list()
    self.assertAllEqual(static_shape, target_shape)
def testExtraIndicesDimensions(self):
    """Check scatter_nd_update with indices carrying extra leading dims.

    Indices are [1, 1, 2] and updates are [1, 1] against a [2, 2] int32
    variable. Both the static shape and the evaluated result (all zeros)
    are verified.
    """
    idx = array_ops.zeros([1, 1, 2], dtypes.int32)
    new_values = array_ops.zeros([1, 1], dtypes.int32)
    target_shape = np.array([2, 2])
    target = variables.Variable(array_ops.zeros(target_shape, dtypes.int32))

    update_op = state_ops.scatter_nd_update(target, idx, new_values)
    self.assertAllEqual(update_op.get_shape().as_list(), target_shape)

    want = np.zeros([2, 2], dtype=np.int32)
    with self.cached_session():
        target.initializer.run()
        got = update_op.eval()
        self.assertAllEqual(want, got)
def testSimple(self):
    """Scatter four scalars into an eight-element float32 variable.

    Values 9, 10, 11, 12 are written at positions 4, 3, 1, 7 and the
    evaluated result is compared with the expected dense vector.
    """
    positions = constant_op.constant(
        [[4], [3], [1], [7]], dtype=dtypes.int32)
    new_values = constant_op.constant(
        [9, 10, 11, 12], dtype=dtypes.float32)
    target = variables.Variable(
        [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    want = np.array([0, 11, 0, 10, 9, 0, 0, 12])

    update_op = state_ops.scatter_nd_update(target, positions, new_values)
    init_op = variables.global_variables_initializer()

    with self.session(use_gpu=True) as sess:
        sess.run(init_op)
        got = sess.run(update_op)
        self.assertAllClose(got, want)
def testSimple3(self):
    """Replace one whole row of a 3x2 float32 variable.

    The single index [1] selects row 1, which is overwritten with
    [11., 12.]; the other rows must stay zero.
    """
    indices = constant_op.constant([[1]], dtype=dtypes.int32)
    updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    # self.session() replaces the deprecated self.test_session() helper.
    with self.session(use_gpu=True) as sess:
        sess.run(init)
        result = sess.run(scatter)
        self.assertAllClose(result, expected)
def testSimple3(self):
    """Replace one whole row of a 3x2 float32 variable.

    The single index [1] selects row 1, which is overwritten with
    [11., 12.]; the other rows must stay zero.
    """
    indices = constant_op.constant([[1]], dtype=dtypes.int32)
    updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    # The session object is never referenced directly (all evaluation goes
    # through self.evaluate), so it is not bound to a name.
    with self.session(use_gpu=True):
        self.evaluate(init)
        result = self.evaluate(scatter)
        self.assertAllClose(result, expected)
def testSimpleResource(self):
    """Scatter into a ResourceVariable for int32 and float32.

    For each dtype, values 9, 10, 11, 12 are written at positions
    4, 3, 1, 7 of an eight-element variable, and the variable itself is
    compared against the expected dense vector.
    """
    positions = constant_op.constant(
        [[4], [3], [1], [7]], dtype=dtypes.int32)
    dtypes_under_test = (dtypes.int32, dtypes.float32)

    for dtype in dtypes_under_test:
        new_values = constant_op.constant([9, 10, 11, 12], dtype=dtype)
        target = resource_variable_ops.ResourceVariable(
            [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype)
        want = np.array([0, 11, 0, 10, 9, 0, 0, 12])
        update_op = state_ops.scatter_nd_update(
            target, positions, new_values)

        with test_util.device(use_gpu=True):
            self.evaluate(target.initializer)
            self.evaluate(update_op)
            # The variable (not the op's output) is what gets checked.
            self.assertAllClose(target, want)
def testSimpleResource(self):
    """Scatter four scalars into a float32 ResourceVariable.

    Values 9, 10, 11, 12 are written at positions 4, 3, 1, 7 of an
    eight-element variable; the variable's value is read back afterwards
    and compared with the expected dense vector.
    """
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = resource_variable_ops.ResourceVariable(
        [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    # self.session() replaces the deprecated self.test_session() helper.
    with self.session(use_gpu=True) as sess:
        sess.run(init)
        sess.run(scatter)
        # ref.eval() is kept inside the session context, as in the original.
        self.assertAllClose(ref.eval(), expected)
def testSimple(self):
    """Scatter four scalars into an eight-element variable, per dtype.

    Runs the same update (values 9, 10, 11, 12 at positions 4, 3, 1, 7)
    for integer, floating-point and complex dtypes, and checks the
    evaluated result against the expected dense vector.
    """
    positions = constant_op.constant(
        [[4], [3], [1], [7]], dtype=dtypes.int32)
    dtypes_under_test = (
        dtypes.int32,
        dtypes.int64,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    )

    for dtype in dtypes_under_test:
        new_values = constant_op.constant([9, 10, 11, 12], dtype=dtype)
        target = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype)
        want = np.array([0, 11, 0, 10, 9, 0, 0, 12])
        update_op = state_ops.scatter_nd_update(
            target, positions, new_values)
        init_op = variables.global_variables_initializer()

        with test_util.use_gpu():
            self.evaluate(init_op)
            got = self.evaluate(update_op)
            self.assertAllClose(got, want)
def testExtraIndicesDimensions(self):
    """Check scatter_nd and scatter_nd_update with extra index dimensions.

    Indices are [1, 1, 2] and updates are [1, 1] for a [2, 2] target.
    Both the functional op (scatter_nd) and the variable-updating op
    (scatter_nd_update) must report the target's static shape and
    evaluate to an all-zero int32 matrix.
    """
    indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    updates = array_ops.zeros([1, 1], dtypes.int32)
    shape = np.array([2, 2])
    expected_result = np.zeros([2, 2], dtype=np.int32)

    # Functional variant.
    scatter = array_ops.scatter_nd(indices, updates, shape)
    self.assertAllEqual(scatter.get_shape().as_list(), shape)
    # cached_session() replaces the deprecated test_session() helper.
    with self.cached_session():
        self.assertAllEqual(expected_result, scatter.eval())

    # Variable-updating variant.
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
    self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
    with self.cached_session():
        ref.initializer.run()
        self.assertAllEqual(expected_result, scatter_update.eval())