def testScatterNdAddStateOps(self):
  """Verify `state_ops.scatter_nd_add` on a resource variable in eager mode.

  Starts from an 8-element float32 variable of ones, adds 9, 10, 11 and 12
  at positions 4, 3, 1 and 7 respectively, and compares the variable's
  contents with the hand-computed result.
  """
  with context.eager_mode():
    var = resource_variable_ops.ResourceVariable(
        [1] * 8, dtype=dtypes.float32, name="add")
    positions = constant_op.constant(
        [[4], [3], [1], [7]], dtype=dtypes.int32)
    deltas = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)

    # The op's return value is ignored; the check reads the variable itself.
    state_ops.scatter_nd_add(var, positions, deltas)

    # Untouched slots stay at 1; each updated slot holds 1 + its delta.
    want = np.array([1, 12, 1, 11, 10, 1, 1, 13])
    self.assertAllClose(want, var.numpy())
def testScatterNdAddStateOps(self):
  """Eager-mode check that `scatter_nd_add` updates a variable in place.

  A float32 resource variable holding eight ones receives the updates
  [9, 10, 11, 12] at indices [4], [3], [1], [7]; the resulting contents are
  compared against the expected vector.
  """
  # NOTE(review): an identical method with this name appears directly before
  # this one in the file. If both live in the same class, only the later
  # definition is kept -- confirm whether the duplicate is intended.
  with context.eager_mode():
    initial_values = [1, 1, 1, 1, 1, 1, 1, 1]
    target = resource_variable_ops.ResourceVariable(
        initial_values, dtype=dtypes.float32, name="add")

    scatter_indices = constant_op.constant(
        [[4], [3], [1], [7]], dtype=dtypes.int32)
    scatter_updates = constant_op.constant(
        [9, 10, 11, 12], dtype=dtypes.float32)
    expected_values = np.array([1, 12, 1, 11, 10, 1, 1, 13])

    state_ops.scatter_nd_add(target, scatter_indices, scatter_updates)

    actual_values = target.numpy()
    self.assertAllClose(expected_values, actual_values)
def testConcurrentUpdates(self):
  """Check that many scatter-add updates to one element accumulate fully.

  Builds 10000 random float64 updates that all target element [0, 1] of a
  2x2 zero variable, applies them with a single `scatter_nd_add`, and
  verifies that the element equals the sum of all updates while every other
  element stays zero.
  """
  num_updates = 10000
  update_values = np.random.rand(num_updates)

  ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
  # Every update row points at the same location, [0, 1].
  indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
  updates = constant_op.constant(update_values, dtype=dtypes.float64)

  expected_result = np.zeros([2, 2], dtype=np.float64)
  expected_result[0, 1] = np.sum(update_values)

  scatter = state_ops.scatter_nd_add(ref, indices, updates)
  init = variables.global_variables_initializer()

  self.evaluate(init)
  result = self.evaluate(scatter)

  # Use the test-case assertion rather than a bare `assert`: it is not
  # stripped under `python -O` and it reports the mismatching values.
  self.assertAllClose(expected_result, result)
def testConcurrentUpdates(self):
  """Check that many scatter-add updates to one element accumulate fully.

  Builds 10000 random float64 updates that all target element [0, 1] of a
  2x2 zero variable, runs a single `scatter_nd_add` in an explicit session,
  and verifies that the element equals the sum of all updates while every
  other element stays zero.
  """
  # NOTE(review): a method with this same name (using `self.evaluate` instead
  # of an explicit session) appears directly before this one in the file. If
  # both live in the same class, only the later definition is kept -- confirm.
  num_updates = 10000
  update_values = np.random.rand(num_updates)

  ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
  # Every update row points at the same location, [0, 1].
  indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
  updates = constant_op.constant(update_values, dtype=dtypes.float64)

  expected_result = np.zeros([2, 2], dtype=np.float64)
  expected_result[0, 1] = np.sum(update_values)

  scatter = state_ops.scatter_nd_add(ref, indices, updates)
  init = variables.global_variables_initializer()

  with session.Session() as sess:
    sess.run(init)
    result = sess.run(scatter)

  # Use the test-case assertion rather than a bare `assert`: it is not
  # stripped under `python -O` and it reports the mismatching values.
  # `result` is already a fetched value, so the session is no longer needed.
  self.assertAllClose(expected_result, result)