def GetOptimizer(self, lr):
  """Builds the DistributedShampoo optimizer for this task.

  Args:
    lr: Learning rate (scalar or schedule tensor) forwarded verbatim to the
      optimizer.

  Returns:
    A `distributed_shampoo.DistributedShampoo` instance configured from this
    task's hyperparameters and driven by `self.theta.global_step`.
  """
  p = self.params
  # Every knob below is copied straight from the task hyperparams; only the
  # learning rate and global step come from elsewhere.
  return distributed_shampoo.DistributedShampoo(
      learning_rate=lr,
      momentum=p.momentum,
      start_preconditioning_steps=p.start_preconditioning_steps,
      initial_accumulator_value=p.initial_accumulator_value,
      matrix_epsilon=p.matrix_epsilon,
      statistics_computation_frequency=p.statistics_computation_frequency,
      second_moment_averaging=p.second_moment_averaging,
      max_any_dim=p.max_any_dim,
      block_size=p.block_size,
      global_step=self.theta.global_step)
def testShampooWithMatrixShapedTensors(self):
  """Checks Shampoo statistics, preconditioners and updates on a 2-D var.

  Runs two training steps against a [4, 2] variable and compares the
  optimizer's slot variables and the applied update against a NumPy
  re-implementation of the Shampoo math.
  """
  # Parameter matrix of size [4,2] would result in L_{t}, and R_{t} of
  # sizes [4, 4] and [2, 2]
  size = [4, 2]
  init_var_np = np.zeros(size)
  # Initialize gradient as random tensor.
  grad_np = np.random.rand(size[0], size[1])

  with tf.Session() as sess:
    global_step = tf.Variable(0, dtype=tf.int64)
    var = tf.Variable(init_var_np, dtype=tf.float32)
    # Gradient is held constant, so both steps apply the same grad_np.
    grad = tf.constant(grad_np, dtype=tf.float32)

    opt = distributed_shampoo.DistributedShampoo(
        learning_rate=1.0,
        momentum=0.0,
        start_preconditioning_steps=0,
        synchronous_preconditioning=True,
        global_step=global_step)

    # Run a single step of gradient update.
    update = opt.apply_gradients(zip([grad], [var]), global_step=global_step)

    # Preconditioner computation and assignments to variables.
    compute_preconditioner_op = opt.invoke_async_preconditioner_computation(
        tf.cast(global_step, tf.int32))
    assign_preconditioners_to_vars_op = (
        opt.assign_preconditioner_to_host_vars())

    self.evaluate(tf.global_variables_initializer())
    tf.tables_initializer().run()
    # Before any step, the variable must equal its initializer.
    init_val = sess.run(var)
    self.assertAllCloseAccordingToType(init_var_np, init_val)

    def np_power(mat_g, alpha, matrix_epsilon=1e-6):
      """Computes mat_g^alpha for a square symmetric matrix mat_g."""
      # Regularize with epsilon*I so the SVD-based power is well defined
      # even when mat_g is singular.
      mat_for_svd = mat_g + np.eye(mat_g.shape[0]) * matrix_epsilon
      mat_u, diag_d, mat_v = np.linalg.svd(mat_for_svd, full_matrices=True)
      diag_d = np.power(np.maximum(diag_d, matrix_epsilon), alpha)
      return np.dot(mat_u, np.dot(np.diag(diag_d), mat_v))

    def norm(val):
      # Frobenius norm of an array.
      return np.sqrt(np.sum(np.square(val)))

    # Run a step of preconditioner update.
    update.run()

    # Left statistics L_t = G G^T, accumulated by the optimizer in slot
    # 'mat_statistics_0'.
    mat_g1 = np.dot(grad_np, grad_np.transpose())
    expected_mat_g1 = sess.run(opt.get_slot(var, 'mat_statistics_0'))
    self.assertAllCloseAccordingToType(mat_g1, expected_mat_g1, atol=1e-1)

    # Right statistics R_t = G^T G, slot 'mat_statistics_1'.
    mat_g2 = np.dot(grad_np.transpose(), grad_np)
    expected_mat_g2 = sess.run(opt.get_slot(var, 'mat_statistics_1'))
    self.assertAllCloseAccordingToType(mat_g2, expected_mat_g2, atol=1e-1)

    # Trigger the (async) preconditioner computation and copy the results
    # into the slot variables before reading them.
    compute_preconditioner_op.run()
    assign_preconditioners_to_vars_op.run()

    # Preconditioners are the -1/4 matrix powers of the statistics.
    mat_left = np_power(mat_g1, -0.25)
    expected_mat_left = sess.run(opt.get_slot(var, 'mat_preconditioner_0'))
    self.assertAllCloseAccordingToType(mat_left, expected_mat_left, atol=1e-1)

    mat_right = np_power(mat_g2, -0.25)
    expected_mat_right = sess.run(opt.get_slot(var, 'mat_preconditioner_1'))
    self.assertAllCloseAccordingToType(mat_right, expected_mat_right,
                                       atol=1e-1)

    # As the preconditioners are initialized to all zero. We don't make
    # any update.
    var_step_0_val = sess.run(var)
    self.assertAllCloseAccordingToType(init_var_np, var_step_0_val, atol=1e-1)

    # Run another step of training.
    update.run()
    var_step_1_val = sess.run(var)

    # New update has the scale of the second diagonal adagrad update.
    # NOTE(review): with a constant gradient, the adagrad accumulator after
    # two steps is 2*grad^2 — hence the factor of 2 below.
    adagrad_update = grad_np / np.sqrt(2 * np.square(grad_np))
    preconditioned_grad_update = np.dot(np.dot(mat_left, grad_np), mat_right)

    # With normalization by diagonal enabled.
    var_step_1_np = init_var_np - preconditioned_grad_update * norm(
        adagrad_update) / norm(preconditioned_grad_update)
    self.assertAllCloseAccordingToType(var_step_1_np, var_step_1_val,
                                       atol=1e-1)

    # Compute new preconditioners.
    compute_preconditioner_op.run()
    assign_preconditioners_to_vars_op.run()

    # Gradients are summed over time.
    mat_g1 += np.dot(grad_np, grad_np.transpose())
    mat_left = np_power(mat_g1, -0.25)
    expected_mat_left = sess.run(opt.get_slot(var, 'mat_preconditioner_0'))
    self.assertAllCloseAccordingToType(mat_left, expected_mat_left, atol=1e-1)

    mat_g2 += np.dot(grad_np.transpose(), grad_np)
    mat_right = np_power(mat_g2, -0.25)
    expected_mat_right = sess.run(opt.get_slot(var, 'mat_preconditioner_1'))
    self.assertAllCloseAccordingToType(mat_right, expected_mat_right,
                                       atol=1e-1)