def corrections_func(
    mainPN,
    batch_size,
    trace_length,
    corrections=False,
    cube=None,
    clip_lola_update_norm=False,
    lola_correction_multiplier=1.0,
    clip_lola_correction_norm=False,
    clip_lola_actor_norm=False,
    against_destabilizer_exploiter=False,
):
    """Computes corrections for policy gradients.

    Builds the first-order policy-gradient term of each of the two networks
    and, when ``corrections`` is set, adds a second-order cross term to it.
    Nothing is returned: all results are stored as attributes on
    ``mainPN[0]`` and ``mainPN[1]``. ``delta`` holds the final update; the
    other attributes (``grad``, ``grad_sum``, ``actor_loss``, ``v_0_log`` /
    ``v_1_log``, ``v_0_grad_01`` / ``v_1_grad_10``, ``second_order``, ...)
    expose intermediate tensors.

    Args:
    -----
        mainPN: list of policy/Q-networks
            Only indices 0 and 1 are used. Each entry must provide
            `log_pi_action_bs_t`, `sample_reward`, `target`, `value`,
            `gamma_array` and `parameters`.
        batch_size: int
        trace_length: int
        corrections: bool (default: False)
            Whether policy networks should use corrections.
        cube: tf.Variable or None (default: None)
            If provided, should be constructed via `lola.utils.make_cube`.
            Used for variance reduction of the value estimation.
            When provided, the computation graph for corrections is faster to
            compile but is quite memory inefficient. When None, variance
            reduction graph is constructed dynamically, is a little longer to
            compile, but has lower memory footprint.
        clip_lola_update_norm: falsy or number (default: False)
            When truthy, passed as the clip norm to `tf.clip_by_norm` for the
            final updates. Only applied when `corrections` is True.
        lola_correction_multiplier: number (default: 1.0)
            Factor applied to both second-order terms.
        clip_lola_correction_norm: falsy or number (default: False)
            When truthy, clip norm applied to the second-order terms.
        clip_lola_actor_norm: falsy or number (default: False)
            When truthy, clip norm applied to the first-order gradients
            before they are added to the second-order terms. The unclipped
            gradients are what is stored in `grad`.
        against_destabilizer_exploiter: bool (default: False)
            Keeps `v_1` per sample (shape `[batch_size]`) instead of reducing
            it to a scalar, and differentiates it sample by sample. Only
            supported on the `cube=None` path.

    Raises:
    -------
        ValueError: if `corrections` and `against_destabilizer_exploiter` are
            both set while `cube` is provided.
    """
    if corrections and against_destabilizer_exploiter and cube is not None:
        # The cube path reduces v_1 to a scalar, but the per-sample
        # differentiation below needs one v_1 entry per batch element.
        raise ValueError(
            "against_destabilizer_exploiter=True requires cube=None when "
            "corrections=True: the cube path only produces a scalar v_1.")

    # not mem_efficient
    if cube is not None:
        ac_logp0 = tf.reshape(mainPN[0].log_pi_action_bs_t,
                              [batch_size, 1, trace_length])
        ac_logp1 = tf.reshape(mainPN[1].log_pi_action_bs_t,
                              [batch_size, trace_length, 1])
        # Outer product of the two agents' log-probs over time, flattened.
        mat_1 = tf.reshape(
            tf.squeeze(tf.matmul(ac_logp1, ac_logp0)),
            [batch_size, 1, trace_length * trace_length],
        )

        v_0 = tf.matmul(
            tf.reshape(mainPN[0].sample_reward,
                       [batch_size, trace_length, 1]),
            mat_1,
        )
        v_0 = tf.reshape(
            v_0, [batch_size, trace_length, trace_length, trace_length])

        v_1 = tf.matmul(
            tf.reshape(mainPN[1].sample_reward,
                       [batch_size, trace_length, 1]),
            mat_1,
        )
        v_1 = tf.reshape(
            v_1, [batch_size, trace_length, trace_length, trace_length])

        # `cube` selects which (reward step, step, step) triples contribute.
        v_0 = 2 * tf.reduce_sum(v_0 * cube) / batch_size
        v_1 = 2 * tf.reduce_sum(v_1 * cube) / batch_size
    # mem_efficient
    else:
        ac_logp0 = tf.reshape(mainPN[0].log_pi_action_bs_t,
                              [batch_size, trace_length])
        ac_logp1 = tf.reshape(mainPN[1].log_pi_action_bs_t,
                              [batch_size, trace_length])

        # Static exclusive cumsum: entry i is the sum over steps < i.
        ac_logp0_cumsum = [tf.constant(0.0)]
        ac_logp1_cumsum = [tf.constant(0.0)]
        for i in range(trace_length - 1):
            ac_logp0_cumsum.append(
                tf.add(ac_logp0_cumsum[-1], ac_logp0[:, i]))
            ac_logp1_cumsum.append(
                tf.add(ac_logp1_cumsum[-1], ac_logp1[:, i]))

        # Compute v_0 and v_1.
        # After step i, mat_cumsum equals
        #   (sum_{k<=i} logp0_k) * (sum_{k<=i} logp1_k),
        # built incrementally from the three terms added below.
        mat_cumsum = ac_logp0[:, 0] * ac_logp1[:, 0]
        v_0 = mat_cumsum * mainPN[0].sample_reward[:, 0]
        v_1 = mat_cumsum * mainPN[1].sample_reward[:, 0]
        for i in range(1, trace_length):
            mat_cumsum = tf.add(mat_cumsum, ac_logp0[:, i] * ac_logp1[:, i])
            mat_cumsum = tf.add(mat_cumsum,
                                ac_logp0_cumsum[i] * ac_logp1[:, i])
            mat_cumsum = tf.add(mat_cumsum,
                                ac_logp1_cumsum[i] * ac_logp0[:, i])
            v_0 = tf.add(v_0, mat_cumsum * mainPN[0].sample_reward[:, i])
            v_1 = tf.add(v_1, mat_cumsum * mainPN[1].sample_reward[:, i])
        v_0 = 2 * tf.reduce_sum(v_0) / batch_size
        if against_destabilizer_exploiter:
            # Keep one value per sample so it can be differentiated per sample.
            v_1 = 2 * v_1 / batch_size
        else:
            v_1 = 2 * tf.reduce_sum(v_1) / batch_size

    mainPN[0].v_0_log = v_0
    mainPN[1].v_1_log = v_1

    # First-order surrogate losses: each agent's target error weighted by
    # either agent's discounting array and log-probabilities.
    actor_target_error_0 = (
        mainPN[0].target - tf.stop_gradient(mainPN[0].value))
    v_0_pi_0 = (2 * tf.reduce_sum(
        (actor_target_error_0 * mainPN[0].gamma_array) *
        mainPN[0].log_pi_action_bs_t) / batch_size)
    v_0_pi_1 = (2 * tf.reduce_sum(
        (actor_target_error_0 * mainPN[1].gamma_array) *
        mainPN[1].log_pi_action_bs_t) / batch_size)

    actor_target_error_1 = (
        mainPN[1].target - tf.stop_gradient(mainPN[1].value))
    v_1_pi_0 = (2 * tf.reduce_sum(
        (actor_target_error_1 * mainPN[0].gamma_array) *
        mainPN[0].log_pi_action_bs_t) / batch_size)
    v_1_pi_1 = (2 * tf.reduce_sum(
        (actor_target_error_1 * mainPN[1].gamma_array) *
        mainPN[1].log_pi_action_bs_t) / batch_size)

    mainPN[0].actor_target_error = actor_target_error_0
    mainPN[1].actor_target_error = actor_target_error_1
    mainPN[0].actor_loss = v_0_pi_0
    mainPN[1].actor_loss = v_1_pi_1
    mainPN[0].value_used_for_correction = v_0
    mainPN[1].value_used_for_correction = v_1

    v_0_grad_theta_0 = flatgrad(v_0_pi_0, mainPN[0].parameters)
    v_0_grad_theta_1 = flatgrad(v_0_pi_1, mainPN[1].parameters)

    v_1_grad_theta_0 = flatgrad(v_1_pi_0, mainPN[0].parameters)
    v_1_grad_theta_1 = flatgrad(v_1_pi_1, mainPN[1].parameters)

    mainPN[0].grad = v_0_grad_theta_0
    mainPN[1].grad = v_1_grad_theta_1
    mainPN[0].grad_sum = tf.math.reduce_sum(v_0_grad_theta_0)
    mainPN[1].grad_sum = tf.math.reduce_sum(v_1_grad_theta_1)

    mainPN[0].grad_v_1 = v_1_grad_theta_0
    mainPN[1].grad_v_0 = v_0_grad_theta_1

    if corrections:
        v_0_grad_theta_0_wrong = flatgrad(v_0, mainPN[0].parameters)
        if against_destabilizer_exploiter:
            # One flattened gradient per batch element. vectorized_map stacks
            # the per-element results along axis 0: [batch_size, param_len].
            v_1_grad_theta_1_wrong = tf.vectorized_map(
                partial(flatgrad, var_list=mainPN[1].parameters), v_1)
        else:
            v_1_grad_theta_1_wrong = flatgrad(v_1, mainPN[1].parameters)

        param_len = v_0_grad_theta_0_wrong.get_shape()[0].value

        if against_destabilizer_exploiter:
            # Transpose so that column b holds sample b's gradient; a bare
            # reshape of the [batch_size, param_len] tensor would interleave
            # samples and parameters. The reshape only pins the static shape.
            multiply0 = tf.matmul(
                tf.reshape(tf.stop_gradient(v_0_grad_theta_1),
                           [1, param_len]),
                tf.reshape(tf.transpose(v_1_grad_theta_1_wrong),
                           [param_len, batch_size]),
            )
        else:
            multiply0 = tf.matmul(
                tf.reshape(tf.stop_gradient(v_0_grad_theta_1),
                           [1, param_len]),
                tf.reshape(v_1_grad_theta_1_wrong, [param_len, 1]),
            )
        multiply1 = tf.matmul(
            tf.reshape(tf.stop_gradient(v_1_grad_theta_0), [1, param_len]),
            tf.reshape(v_0_grad_theta_0_wrong, [param_len, 1]),
        )

        second_order0 = flatgrad(multiply0, mainPN[0].parameters)
        if against_destabilizer_exploiter:
            # Add a trailing axis; it is summed away again further down.
            second_order0 = second_order0[:, None]
        second_order1 = flatgrad(multiply1, mainPN[1].parameters)

        mainPN[0].multiply0 = multiply0
        mainPN[0].v_0_grad_01 = second_order0
        mainPN[1].v_1_grad_10 = second_order1
        mainPN[0].second_order = tf.math.reduce_sum(second_order0)
        mainPN[1].second_order = tf.math.reduce_sum(second_order1)

        if against_destabilizer_exploiter:
            second_order0 = tf.math.reduce_sum(second_order0, axis=1)

        second_order0 = second_order0 * lola_correction_multiplier
        second_order1 = second_order1 * lola_correction_multiplier
        if clip_lola_correction_norm:
            second_order0 = tf.clip_by_norm(
                second_order0, clip_lola_correction_norm,
                axes=None, name=None)
            second_order1 = tf.clip_by_norm(
                second_order1, clip_lola_correction_norm,
                axes=None, name=None)
        if clip_lola_actor_norm:
            v_0_grad_theta_0 = tf.clip_by_norm(
                v_0_grad_theta_0, clip_lola_actor_norm,
                axes=None, name=None)
            v_1_grad_theta_1 = tf.clip_by_norm(
                v_1_grad_theta_1, clip_lola_actor_norm,
                axes=None, name=None)

        delta_0 = v_0_grad_theta_0 + second_order0
        delta_1 = v_1_grad_theta_1 + second_order1

        if clip_lola_update_norm:
            delta_0 = tf.clip_by_norm(
                delta_0, clip_lola_update_norm, axes=None, name=None)
            delta_1 = tf.clip_by_norm(
                delta_1, clip_lola_update_norm, axes=None, name=None)

        mainPN[0].delta = delta_0
        mainPN[1].delta = delta_1
    else:
        mainPN[0].delta = v_0_grad_theta_0
        mainPN[1].delta = v_1_grad_theta_1

        # To prevent some logic about logging stuff: zero-valued stand-ins
        # for the second-order terms, each derived from its own network.
        mainPN[0].v_0_grad_01 = tf.reduce_sum(v_0_grad_theta_0) * 0.0
        mainPN[1].v_1_grad_10 = tf.reduce_sum(v_1_grad_theta_1) * 0.0
def simple_actor_training_func(policy_network, opp_policy_network,
                               batch_size, trace_length, cube=None):
    """Builds the plain (uncorrected) actor update for `policy_network`.

    Nothing is returned; results are stored as attributes on
    `policy_network`: `v_0_log`, `actor_target_error`, `actor_loss`,
    `value_used_for_correction`, `grad`, `grad_sum`, `delta` (equal to
    `grad`) and a zero-valued `v_0_grad_01`.

    Args:
    -----
        policy_network: network being trained. Must provide
            `log_pi_action_bs_t`, `sample_reward`, `target`, `value`,
            `gamma_array` and `parameters`.
        opp_policy_network: the other agent's network. Only its
            `log_pi_action_bs_t` is read (for the value estimate `v_0`).
        batch_size: int
        trace_length: int
        cube: tensor or None (default: None)
            When provided, `v_0` is computed through a dense
            `[batch_size, trace_length, trace_length, trace_length]` tensor
            multiplied by `cube`; when None, it is accumulated step by step
            with running sums.
    """
    # not mem_efficient
    if cube is not None:
        ac_logp0 = tf.reshape(policy_network.log_pi_action_bs_t,
                              [batch_size, 1, trace_length])
        ac_logp1 = tf.reshape(
            opp_policy_network.log_pi_action_bs_t,
            [batch_size, trace_length, 1],
        )
        # Outer product of the two agents' log-probs over time, flattened.
        mat_1 = tf.reshape(
            tf.squeeze(tf.matmul(ac_logp1, ac_logp0)),
            [batch_size, 1, trace_length * trace_length],
        )

        v_0 = tf.matmul(
            tf.reshape(policy_network.sample_reward,
                       [batch_size, trace_length, 1]),
            mat_1,
        )
        v_0 = tf.reshape(
            v_0, [batch_size, trace_length, trace_length, trace_length])

        v_0 = 2 * tf.reduce_sum(v_0 * cube) / batch_size
    # mem_efficient
    else:
        ac_logp0 = tf.reshape(policy_network.log_pi_action_bs_t,
                              [batch_size, trace_length])
        ac_logp1 = tf.reshape(opp_policy_network.log_pi_action_bs_t,
                              [batch_size, trace_length])

        # Static exclusive cumsum: entry i is the sum over steps < i.
        ac_logp0_cumsum = [tf.constant(0.0)]
        ac_logp1_cumsum = [tf.constant(0.0)]
        for i in range(trace_length - 1):
            ac_logp0_cumsum.append(
                tf.add(ac_logp0_cumsum[-1], ac_logp0[:, i]))
            ac_logp1_cumsum.append(
                tf.add(ac_logp1_cumsum[-1], ac_logp1[:, i]))

        # Compute v_0.
        # After step i, mat_cumsum equals
        #   (sum_{k<=i} logp0_k) * (sum_{k<=i} logp1_k).
        mat_cumsum = ac_logp0[:, 0] * ac_logp1[:, 0]
        v_0 = mat_cumsum * policy_network.sample_reward[:, 0]
        for i in range(1, trace_length):
            mat_cumsum = tf.add(mat_cumsum, ac_logp0[:, i] * ac_logp1[:, i])
            mat_cumsum = tf.add(mat_cumsum,
                                ac_logp0_cumsum[i] * ac_logp1[:, i])
            mat_cumsum = tf.add(mat_cumsum,
                                ac_logp1_cumsum[i] * ac_logp0[:, i])
            v_0 = tf.add(v_0, mat_cumsum * policy_network.sample_reward[:, i])
        v_0 = 2 * tf.reduce_sum(v_0) / batch_size

    policy_network.v_0_log = v_0

    actor_target_error_0 = policy_network.target - tf.stop_gradient(
        policy_network.value)
    v_0_pi_0 = (2 * tf.reduce_sum(
        (actor_target_error_0 * policy_network.gamma_array) *
        policy_network.log_pi_action_bs_t) / batch_size)

    policy_network.actor_target_error = actor_target_error_0
    policy_network.actor_loss = v_0_pi_0
    policy_network.value_used_for_correction = v_0

    v_0_grad_theta_0 = flatgrad(v_0_pi_0, policy_network.parameters)

    policy_network.grad = v_0_grad_theta_0
    policy_network.grad_sum = tf.math.reduce_sum(v_0_grad_theta_0)

    # No second-order term here: the update is the plain gradient.
    policy_network.delta = v_0_grad_theta_0

    # To prevent some logic about logging stuff
    policy_network.v_0_grad_01 = tf.reduce_sum(v_0_grad_theta_0) * 0.0
def corrections_func(mainQN, corrections, gamma, pseudo, reg):
    """Builds value functions and corrected updates for two agents.

    Each agent's `p_act` is read as a column of logits: rows 0-3 feed the
    4x4 matrix `P` below, row 4 feeds the start distribution. Each value is
    `(inv(I - gamma * P) @ R)^T @ s_0`, with `R` supplied through a
    placeholder.

    Results are stored on the agents rather than returned:
    `lr_correction` (placeholder, shape `[1]`), `policy`, `R1` (reward
    placeholder, shape `[4, 1]`), `v` and `delta`. `delta` is the agent's
    own value gradient plus `lr_correction` times a second-order term.

    Args:
        mainQN: sequence of agents; indices 0 and 1 are used. Each must
            provide `p_act` and `parameters`.
        corrections: not referenced in this function body.
        gamma: discount factor multiplying `P`.
        pseudo: when truthy, the cross terms are built without
            `tf.stop_gradient`; otherwise the opponent-side gradient is
            held constant.
        reg: when > 0, a penalty `reg * l2_loss(square(param))` per
            parameter tensor is subtracted from both values.

    NOTE(review): this source contains another `corrections_func` with a
    different signature; if both live in the same module, this later
    definition replaces the earlier one -- confirm.
    """
    agent_a = mainQN[0]
    agent_b = mainQN[1]

    agent_a.lr_correction = tf.placeholder(shape=[1], dtype=tf.float32)
    agent_b.lr_correction = tf.placeholder(shape=[1], dtype=tf.float32)

    logits_a = agent_a.p_act
    logits_b = agent_b.p_act

    # Rows 0..3 and row 4 of each logit column are handled separately.
    state_logits_a = tf.slice(logits_a, [0, 0], [4, 1])
    state_logits_b = tf.slice(logits_b, [0, 0], [4, 1])
    start_logit_a = tf.slice(logits_a, [4, 0], [1, 1])
    start_logit_b = tf.slice(logits_b, [4, 0], [1, 1])

    prob_a = tf.nn.sigmoid(state_logits_a)
    prob_b = tf.nn.sigmoid(state_logits_b)
    agent_a.policy = tf.nn.sigmoid(logits_a)
    agent_b.policy = tf.nn.sigmoid(logits_b)

    start_prob_a = tf.nn.sigmoid(start_logit_a)
    start_prob_b = tf.nn.sigmoid(start_logit_b)
    start_dist_a = tf.concat([start_prob_a, (1 - start_prob_a)], 0)
    start_dist_b = tf.concat([start_prob_b, (1 - start_prob_b)], 0)
    # Joint start distribution, flattened to a column.
    start_dist = tf.reshape(
        tf.matmul(start_dist_a, tf.transpose(start_dist_b)), [-1, 1])

    # CC, CD, DC, DD
    joint_columns = [
        tf.multiply(prob_a, prob_b),
        tf.multiply(prob_a, 1 - prob_b),
        tf.multiply(1 - prob_a, prob_b),
        tf.multiply(1 - prob_a, 1 - prob_b),
    ]
    transition = tf.concat(joint_columns, 1)

    reward_a = tf.placeholder(shape=[4, 1], dtype=tf.float32)
    reward_b = tf.placeholder(shape=[4, 1], dtype=tf.float32)

    eye_minus_gamma_p = tf.diag([1.0, 1.0, 1.0, 1.0]) - transition * gamma
    value_a = tf.matmul(
        tf.matmul(tf.matrix_inverse(eye_minus_gamma_p), reward_a),
        start_dist,
        transpose_a=True,
    )
    value_b = tf.matmul(
        tf.matmul(tf.matrix_inverse(eye_minus_gamma_p), reward_b),
        start_dist,
        transpose_a=True,
    )

    if reg > 0:
        for position, _ in enumerate(agent_a.parameters):
            penalty_a = tf.reduce_sum(
                tf.nn.l2_loss(tf.square(agent_a.parameters[position])))
            value_a -= reg * penalty_a
            penalty_b = tf.reduce_sum(
                tf.nn.l2_loss(tf.square(agent_b.parameters[position])))
            value_b -= reg * penalty_b

    # Gradients of each value w.r.t. each agent's parameters.
    grad_va_wrt_a = flatgrad(value_a, agent_a.parameters)
    grad_va_wrt_b = flatgrad(value_a, agent_b.parameters)
    grad_vb_wrt_a = flatgrad(value_b, agent_a.parameters)
    grad_vb_wrt_b = flatgrad(value_b, agent_b.parameters)

    # Second copies of the own-value gradients, used by the non-pseudo path.
    grad_va_wrt_a_again = flatgrad(value_a, agent_a.parameters)
    grad_vb_wrt_b_again = flatgrad(value_b, agent_b.parameters)

    n_params = grad_va_wrt_a_again.get_shape()[0].value

    if pseudo:
        left_a, right_a = grad_va_wrt_b, grad_vb_wrt_b
        left_b, right_b = grad_vb_wrt_a, grad_va_wrt_a
    else:
        left_a = tf.stop_gradient(grad_va_wrt_b)
        right_a = grad_vb_wrt_b_again
        left_b = tf.stop_gradient(grad_vb_wrt_a)
        right_b = grad_va_wrt_a_again

    # Inner products (1 x 1) whose gradients give the second-order terms.
    cross_a = tf.matmul(
        tf.reshape(left_a, [1, n_params]),
        tf.reshape(right_a, [n_params, 1]),
    )
    cross_b = tf.matmul(
        tf.reshape(left_b, [1, n_params]),
        tf.reshape(right_b, [n_params, 1]),
    )

    second_order_a = flatgrad(cross_a, agent_a.parameters)
    second_order_b = flatgrad(cross_b, agent_b.parameters)

    agent_a.R1 = reward_a
    agent_b.R1 = reward_b
    agent_a.v = value_a
    agent_b.v = value_b
    agent_a.delta = grad_va_wrt_a + tf.multiply(
        second_order_a, agent_a.lr_correction)
    agent_b.delta = grad_vb_wrt_b + tf.multiply(
        second_order_b, agent_b.lr_correction)