# Imports assumed by the snippets below (not shown in the original listing).
# Note: `np` here must be jax.numpy, not plain NumPy, for
# `grad(gaussian_logprob)` in the last snippet to trace correctly.
import jax
import jax.numpy as np
import rlax
from jax import grad, jit


def run_data_coop_game(seed, swf, N=500):
    """Plays the two-player data-cooperation game with REINFORCE learners and
    returns `swf` applied to the players' reward histories."""

    def data_coop_reward(a_C, a_DDO):
        if a_C == 0 and a_DDO == 0:  # both defect
            return np.array([1.]), np.array([1.])
        elif a_C == 0 and a_DDO == 1:
            return np.array([6.]), np.array([0.])
        elif a_C == 1 and a_DDO == 0:
            return np.array([0.]), np.array([6.])
        else:  # both cooperate
            return np.array([5.]), np.array([5.])

    rng = jax.random.PRNGKey(seed)
    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])
    log = False
    d = 2
    rng, iter_rng = jax.random.split(rng)
    logits_C = 0.05 * jax.random.normal(iter_rng, shape=(1, d))
    rng, iter_rng = jax.random.split(rng)
    logits_DDO = 0.05 * jax.random.normal(iter_rng, shape=(1, d))
    r_Cs = []
    r_DDOs = []
    for _ in range(N):
        # sample actions given policies
        rng, iter_rng = jax.random.split(rng)
        a_C = jax.random.categorical(iter_rng, logits_C)
        rng, iter_rng = jax.random.split(rng)
        a_DDO = jax.random.categorical(iter_rng, logits_DDO)
        # observe rewards
        r_C, r_DDO = data_coop_reward(a_C, a_DDO)
        r_Cs.append(r_C)
        r_DDOs.append(r_DDO)
        # update policies
        logits_C -= 0.01 * grad_PG_loss(logits_C, a_C, r_C, w_t)
        logits_DDO -= 0.01 * grad_PG_loss(logits_DDO, a_DDO, r_DDO, w_t)
        if log:
            print('C', rlax.policy_gradient_loss(logits_C, a_C, r_C, w_t))
            print('DDO', rlax.policy_gradient_loss(logits_DDO, a_DDO, r_DDO, w_t))
            print('SU', 0.5 * (r_C + r_DDO))
            print(logits_C, logits_DDO)
    # print(.5 * (np.mean(np.array(r_Cs)) + np.mean(np.array(r_DDOs))))
    return swf(np.array(r_Cs), np.array(r_DDOs))
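# Example social welfare functions to pass in as `swf` (hypothetical helpers,
# not defined anywhere in the original code): each receives the full reward
# histories of the two players as arrays.
def utilitarian_swf(r_Cs, r_DDOs):
    # average welfare across both players
    return 0.5 * (r_Cs.mean() + r_DDOs.mean())


def egalitarian_swf(r_Cs, r_DDOs):
    # welfare of the worse-off player (Rawlsian maximin)
    return np.minimum(r_Cs.mean(), r_DDOs.mean())

# e.g. run_data_coop_game(seed=0, swf=utilitarian_swf)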
def run_simple_RL():
    def simple_reward(action):
        if action == 0:
            return np.array([1.])
        else:
            return np.array([0.])

    rng = jax.random.PRNGKey(0)
    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])
    d = 4
    rng, iter_rng = jax.random.split(rng)
    logits = jax.random.normal(iter_rng, shape=(1, d))
    N = 100
    for _ in range(N):
        # sample action given policy
        rng, iter_rng = jax.random.split(rng)
        a = jax.random.categorical(iter_rng, logits)
        # observe reward
        r = simple_reward(a)
        # update policy
        logits -= 0.1 * grad_PG_loss(logits, a, r, w_t)
        print(rlax.policy_gradient_loss(logits, a, r, w_t))
        print(logits)
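# For intuition (a sketch, not part of the original code): what
# `grad_PG_loss` differentiates. rlax.policy_gradient_loss is the negative
# advantage-weighted log-likelihood of the sampled actions (rlax also stops
# gradients through the advantage), so gradient *descent* on it performs
# REINFORCE *ascent* on expected reward.
def pg_loss_by_hand(logits, a, adv, w):
    log_pi = jax.nn.log_softmax(logits)                               # (1, d)
    log_pi_a = np.take_along_axis(log_pi, a[:, None], axis=-1)[:, 0]  # (1,)
    return -np.mean(w * adv * log_pi_a)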
# Additional imports assumed by the agent losses below; `pack`, `network`,
# `discount`, `td_lambda`, `critic_cost` and `entropy_cost` are free
# variables supplied by the enclosing agent code.
import haiku as hk
import jax.numpy as jnp


def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
    """Computes a linear combination of the policy gradient loss and value
    loss and regularizes it with an entropy term."""
    inputs = pack(trajectory)
    # Dynamically unroll the network. This Haiku utility function unpacks the
    # list of input tensors such that the i^{th} row from each input tensor
    # is presented to the i^{th} unrolled RNN module.
    (logits, values, _, _, state_embeddings), new_rnn_unroll_state = (
        hk.dynamic_unroll(network, inputs, rnn_unroll_state))
    trajectory_len = trajectory.actions.shape[0]
    # Compute the combined loss given the output of the model.
    td_errors = rlax.td_lambda(v_tm1=values[:-1, 0],
                               r_t=jnp.squeeze(trajectory.rewards, -1),
                               discount_t=trajectory.discounts * discount,
                               v_t=values[1:, 0],
                               lambda_=jnp.array(td_lambda))
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1, 0],
        a_t=jnp.squeeze(trajectory.actions, 1),
        adv_t=td_errors,
        w_t=jnp.ones(trajectory_len))
    entropy_loss = jnp.mean(
        rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))
    combined_loss = (actor_loss +
                     critic_cost * critic_loss +
                     entropy_cost * entropy_loss)
    return combined_loss, new_rnn_unroll_state
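# How a loss of this shape is typically wired up (a sketch based on the
# bsuite-style agents this resembles; `params`, `trajectory` and `rnn_state`
# come from the surrounding agent): wrapping it in an hk.transform makes the
# network parameters an explicit argument, and `has_aux=True` threads the
# new RNN state out alongside the gradients.
#
#   loss_fn = hk.without_apply_rng(hk.transform(loss)).apply
#   gradients, new_rnn_state = jax.grad(loss_fn, has_aux=True)(
#       params, trajectory, rnn_state)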
# `unroll_fn`, `max_abs_reward`, `discount`, `baseline_cost` and
# `entropy_cost` are free variables supplied by the enclosing learner.
from typing import Dict, Tuple

import reverb
import tree


def loss_fn(params: hk.Params,
            sample: reverb.ReplaySample
            ) -> Tuple[jnp.DeviceArray, Dict[str, jnp.DeviceArray]]:
    """Batched, entropy-regularised actor-critic loss with V-trace."""
    # Extract the data.
    data = sample.data
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
    behaviour_logits = extra['logits']

    # Apply reward clipping.
    rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

    # Unroll current policy over observations.
    (logits, values), _ = unroll_fn(params, observations, initial_state)

    # Compute importance sampling weights: current policy / behavior policy.
    rhos = rlax.categorical_importance_sampling_ratios(logits[:-1],
                                                       behaviour_logits[:-1],
                                                       actions[:-1])

    # Critic loss.
    vtrace_returns = rlax.vtrace_td_error_and_advantage(
        v_tm1=values[:-1],
        v_t=values[1:],
        r_t=rewards[:-1],
        discount_t=discounts[:-1] * discount,
        rho_tm1=rhos)
    critic_loss = jnp.square(vtrace_returns.errors)

    # Policy gradient loss.
    policy_gradient_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=actions[:-1],
        adv_t=vtrace_returns.pg_advantage,
        w_t=jnp.ones_like(rewards[:-1]))

    # Entropy regulariser.
    entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1]))

    # Combine weighted sum of actor & critic losses, averaged over the sequence.
    mean_loss = jnp.mean(policy_gradient_loss +
                         baseline_cost * critic_loss +
                         entropy_cost * entropy_loss)

    metrics = {
        'policy_loss': jnp.mean(policy_gradient_loss),
        'critic_loss': jnp.mean(baseline_cost * critic_loss),
        'entropy_loss': jnp.mean(entropy_cost * entropy_loss),
        'entropy': jnp.mean(entropy_loss),
    }
    return mean_loss, metrics
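# A sketch of one learner step built on `loss_fn` (assumptions: `optimizer`
# is an optax.GradientTransformation created elsewhere; nothing here is from
# the original snippet):
import optax


def sgd_step(params, opt_state, sample):
    # value_and_grad with has_aux returns ((loss, metrics), grads)
    (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, sample)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, metrics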
import functools


def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.ndarray:
    """Entropy-regularised actor-critic loss."""
    # Extract the data.
    observations, actions, rewards, discounts, extra = sample.data
    initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
    behaviour_logits = extra['logits']

    actions = actions[:-1]      # [T-1]
    rewards = rewards[:-1]      # [T-1]
    discounts = discounts[:-1]  # [T-1]
    rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

    # Unroll current policy over observations.
    net = functools.partial(network.apply, params)
    (logits, values), _ = hk.static_unroll(net, observations, initial_state)

    # Compute importance sampling weights: current policy / behavior policy.
    rhos = rlax.categorical_importance_sampling_ratios(
        logits[:-1], behaviour_logits[:-1], actions)

    # Critic loss.
    vtrace_returns = rlax.vtrace_td_error_and_advantage(
        v_tm1=values[:-1],
        v_t=values[1:],
        r_t=rewards,
        discount_t=discounts * discount,
        rho_t=rhos)
    critic_loss = jnp.square(vtrace_returns.errors)

    # Policy gradient loss.
    policy_gradient_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=actions,
        adv_t=vtrace_returns.pg_advantage,
        w_t=jnp.ones_like(rewards))

    # Entropy regulariser.
    entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards))

    # Combine weighted sum of actor & critic losses.
    mean_loss = jnp.mean(policy_gradient_loss +
                         baseline_cost * critic_loss +
                         entropy_cost * entropy_loss)
    return mean_loss
def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
    """Actor-critic loss."""
    logits, values = network(trajectory.observations)
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones_like(td_errors))
    return actor_loss + critic_loss
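# Shape sketch for this loss (an assumption read off the slicing above, not
# stated in the snippet): a length-T trajectory needs T+1 observations so
# that `values[:-1]`/`values[1:]` line up as v_{t-1}/v_t, while `actions`,
# `rewards` and `discounts` all have length T, e.g.
#
#   trajectory = sequence.Trajectory(
#       observations=jnp.zeros((T + 1, obs_dim)),  # obs_dim is hypothetical
#       actions=jnp.zeros((T,), dtype=jnp.int32),
#       rewards=jnp.zeros((T,)),
#       discounts=jnp.ones((T,)))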
def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
    """Actor-critic loss."""
    observations, rewards, actions = pack(trajectory)
    logits, values, _, _, _ = network(observations, rewards, actions)
    td_errors = rlax.td_lambda(v_tm1=values[:-1],
                               r_t=jnp.squeeze(trajectory.rewards, -1),
                               discount_t=trajectory.discounts * discount,
                               v_t=values[1:],
                               lambda_=jnp.array(td_lambda))
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=jnp.squeeze(trajectory.actions, 1),
        adv_t=td_errors,
        w_t=jnp.ones_like(td_errors))
    entropy_loss = jnp.mean(
        rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))
    return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss
def loss(
    weights,
    observations,
    actions,
    rewards,
    td_lambda=0.2,
    discount=0.99,
    policy_cost=0.25,
    entropy_cost=1e-3,
):
    """Actor-critic loss."""
    logits, values = network(weights, observations)
    values = jnp.append(values, jnp.sum(rewards))
    # replace -inf values by tiny finite value
    logits = jnp.maximum(logits, MINIMUM_LOGIT)
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1],
        r_t=rewards,
        discount_t=jnp.full_like(rewards, discount),
        v_t=values[1:],
        lambda_=td_lambda,
    )
    critic_loss = jnp.mean(td_errors ** 2)
    if type_ == "a2c":
        actor_loss = rlax.policy_gradient_loss(
            logits_t=logits,
            a_t=actions,
            adv_t=td_errors,
            w_t=jnp.ones(td_errors.shape[0]),
        )
    elif type_ == "supervised":
        actor_loss = jnp.mean(cross_entropy(logits, actions))
    entropy_loss = -jnp.mean(entropy(logits))
    return policy_cost * actor_loss, critic_loss, entropy_cost * entropy_loss
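# The actor and entropy terms come back pre-weighted by `policy_cost` and
# `entropy_cost`, returned separately (handy for logging); a sketch of
# collapsing them into one scalar for jax.grad (nothing here is from the
# original snippet):
def total_loss(weights, observations, actions, rewards):
    actor, critic, entropy_term = loss(weights, observations, actions, rewards)
    return actor + critic + entropy_term

# grads = jax.grad(total_loss)(weights, observations, actions, rewards)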
def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
    """Actor-critic loss."""
    (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
        network, trajectory.observations[:, None, ...], rnn_unroll_state)
    seq_len = trajectory.actions.shape[0]
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1, 0],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:, 0],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1, 0],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones(seq_len))
    entropy_loss = jnp.mean(
        rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))
    combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss
    return combined_loss, new_rnn_unroll_state
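# Where `rnn_unroll_state` comes from (a sketch; assuming `network` wraps an
# hk.LSTM, as the LSTMState annotation suggests): Haiku RNN cores expose
# `initial_state(batch_size)`, so the first unroll would be seeded with e.g.
# `hk.LSTM(hidden_size).initial_state(batch_size=1)` inside the transform,
# and later calls reuse the `new_rnn_unroll_state` returned by this loss.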
def run_data_coop_game_with_regulator(seed, swf, N=500):
    """Adds a slower-moving regulator that learns per-player tax rates to
    redistribute the players' rewards."""

    def data_coop_reward(a_C, a_DDO):
        if a_C == 0 and a_DDO == 0:  # both defect
            return np.array([1.]), np.array([1.])
        elif a_C == 0 and a_DDO == 1:
            return np.array([6.]), np.array([0.])
        elif a_C == 1 and a_DDO == 0:
            return np.array([0.]), np.array([6.])
        else:  # both cooperate
            return np.array([5.]), np.array([5.])

    # NOTE: this single-tax version of `redistribute` is shadowed by the
    # two-tax version defined immediately below.
    def redistribute(r_C, r_DDO, a_R):
        if a_R == 0:
            tax = 0.
        elif a_R == 1:
            tax = 0.15
        elif a_R == 2:
            tax = 0.3
        else:
            tax = 0.5
        wealth = tax * (r_C + r_DDO)
        r_C = r_C - tax * r_C + wealth / 2.
        r_DDO = r_DDO - tax * r_DDO + wealth / 2.
        return r_C, r_DDO, tax

    def redistribute(r_C, r_DDO, a_R1, a_R2):
        if a_R1 == 0:
            tax1 = 0.
        elif a_R1 == 1:
            tax1 = 0.15
        elif a_R1 == 2:
            tax1 = 0.3
        else:
            tax1 = 0.5
        if a_R2 == 0:
            tax2 = 0.
        elif a_R2 == 1:
            tax2 = 0.15
        elif a_R2 == 2:
            tax2 = 0.3
        else:
            tax2 = 0.5
        wealth = tax1 * r_C + tax2 * r_DDO
        r_C = r_C - tax1 * r_C + wealth / 2.
        r_DDO = r_DDO - tax2 * r_DDO + wealth / 2.
        return r_C, r_DDO, tax1, tax2

    rng = jax.random.PRNGKey(seed)
    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])
    log = False
    d = 2
    rng, iter_rng = jax.random.split(rng)
    logits_C = 0.1 * np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_DDO = 0.1 * np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_R1 = 0.1 * np.array([[1, 1, 1, 1.]])
    logits_R2 = 0.1 * np.array([[1, 1, 1, 1.]])
    r_Cs = []
    r_DDOs = []
    taxes1 = []
    taxes2 = []
    for i in range(N):
        # sample actions given policies
        rng, iter_rng = jax.random.split(rng)
        a_C = jax.random.categorical(iter_rng, logits_C)
        rng, iter_rng = jax.random.split(rng)
        a_DDO = jax.random.categorical(iter_rng, logits_DDO)
        rng, iter_rng = jax.random.split(rng)
        a_R1 = jax.random.categorical(iter_rng, logits_R1)
        rng, iter_rng = jax.random.split(rng)
        a_R2 = jax.random.categorical(iter_rng, logits_R2)
        # observe rewards
        r_C, r_DDO = data_coop_reward(a_C, a_DDO)
        r_Cs.append(r_C)
        r_DDOs.append(r_DDO)
        r_C, r_DDO, tax1, tax2 = redistribute(r_C, r_DDO, a_R1, a_R2)
        taxes1.append(tax1)
        taxes2.append(tax2)
        # update policies
        logits_C -= 0.01 * grad_PG_loss(logits_C, a_C, r_C, w_t)
        logits_DDO -= 0.01 * grad_PG_loss(logits_DDO, a_DDO, r_DDO, w_t)
        # the regulator updates only every `lag` steps, from recent welfare
        lag = 50
        if i % lag == 1:
            R = np.array(r_Cs[-lag:]).mean() + np.array(r_DDOs[-lag:]).mean()
            logits_R1 -= 0.005 * grad_PG_loss(logits_R1, a_R1, .5 * np.array([R]), w_t)
            logits_R2 -= 0.005 * grad_PG_loss(logits_R2, a_R2, .5 * np.array([R]), w_t)
        if log:
            print('C', rlax.policy_gradient_loss(logits_C, a_C, r_C, w_t))
            print('DDO', rlax.policy_gradient_loss(logits_DDO, a_DDO, r_DDO, w_t))
            print('SU', 0.5 * (r_C + r_DDO))
            print('logits:', logits_C, logits_DDO, logits_R1, logits_R2)
    print('mean SU:', .5 * (np.mean(np.array(r_Cs)) + np.mean(np.array(r_DDOs))))
    print('mean tax1', np.array(taxes1).mean())
    print('mean tax2', np.array(taxes2).mean())
    return swf(np.array(r_Cs), np.array(r_DDOs))
def run_data_coop_game_with_gaussian_regulator(seed, swf, N=500):
    """Variant in which the regulator has a Gaussian policy over continuous
    tax rates, trained with REINFORCE on recent average welfare."""

    def data_coop_reward(a_C, a_DDO):
        if a_C == 0 and a_DDO == 0:  # both defect
            return np.array([1.]), np.array([1.])
        elif a_C == 0 and a_DDO == 1:
            return np.array([6.]), np.array([0.])
        elif a_C == 1 and a_DDO == 0:
            return np.array([0.]), np.array([6.])
        else:  # both cooperate
            return np.array([5.]), np.array([5.])

    def gaussian_logprob(logits, a):
        # log-density of a Gaussian with mean `logits` and std 0.1,
        # up to an additive constant
        return np.mean(-((a - logits) / .1)**2)

    def redistribute(r_C, r_DDO, a_R1, a_R2):
        tax1 = 0.5 * jax.nn.sigmoid(a_R1)
        tax2 = 0.5 * jax.nn.sigmoid(a_R2)
        wealth = tax1 * r_C + tax2 * r_DDO
        r_C = r_C - tax1 * r_C + wealth / 2.
        r_DDO = r_DDO - tax2 * r_DDO + wealth / 2.
        return r_C, r_DDO, tax1, tax2

    # Unused variant that taxes both players at the first rate.
    def redistributed(r_C, r_DDO, a_R1, a_R2):
        tax1 = 0.5 * jax.nn.sigmoid(a_R1)
        tax2 = 0.5 * jax.nn.sigmoid(a_R2)
        wealth = tax1 * (r_C + r_DDO)
        r_C = r_C - tax1 * r_C + wealth / 2.
        r_DDO = r_DDO - tax1 * r_DDO + wealth / 2.
        return r_C, r_DDO, tax1, tax2

    rng = jax.random.PRNGKey(seed)
    grad_PG_loss = jit(grad(rlax.policy_gradient_loss))
    w_t = np.array([1.])
    log = False
    d = 2
    rng, iter_rng = jax.random.split(rng)
    logits_C = np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_DDO = np.array([[1, 1.]])
    rng, iter_rng = jax.random.split(rng)
    logits_R1 = np.array([1.])  # the mean of the Gaussian
    logits_R2 = np.array([1.])  # the mean of the Gaussian
    r_Cs = []
    r_DDOs = []
    taxes1 = []
    taxes2 = []
    for i in range(N):
        # sample actions given policies
        rng, iter_rng = jax.random.split(rng)
        a_C = jax.random.categorical(iter_rng, logits_C)
        rng, iter_rng = jax.random.split(rng)
        a_DDO = jax.random.categorical(iter_rng, logits_DDO)
        rng, iter_rng = jax.random.split(rng)
        a_R1 = 0.1 * jax.random.normal(iter_rng) + logits_R1
        rng, iter_rng = jax.random.split(rng)
        a_R2 = 0.1 * jax.random.normal(iter_rng) + logits_R2
        # observe rewards
        r_C, r_DDO = data_coop_reward(a_C, a_DDO)
        r_Cs.append(r_C)
        r_DDOs.append(r_DDO)
        r_C, r_DDO, tax1, tax2 = redistribute(r_C, r_DDO, a_R1, a_R2)
        taxes1.append(tax1)
        taxes2.append(tax2)
        # update policies
        logits_C -= 0.01 * grad_PG_loss(logits_C, a_C, r_C, w_t)
        logits_DDO -= 0.01 * grad_PG_loss(logits_DDO, a_DDO, r_DDO, w_t)
        lag = 50
        if i > 0 and i % lag == 0:
            R = np.array(r_Cs[-lag:]).mean() + np.array(r_DDOs[-lag:]).mean()
            # REINFORCE: ascend the reward-weighted log-likelihood
            logits_R1 += 0.005 * R * grad(gaussian_logprob)(logits_R1, a_R1)
            logits_R2 += 0.005 * R * grad(gaussian_logprob)(logits_R2, a_R2)
        if log:
            print('C', rlax.policy_gradient_loss(logits_C, a_C, r_C, w_t))
            print('DDO', rlax.policy_gradient_loss(logits_DDO, a_DDO, r_DDO, w_t))
            print('SU', 0.5 * (r_C + r_DDO))
            print('logits:', logits_C, logits_DDO, logits_R1, logits_R2)
    print('mean SU:', .5 * (np.mean(np.array(r_Cs)) + np.mean(np.array(r_DDOs))))
    print('mean tax1', np.array(taxes1).mean())
    print('mean tax2', np.array(taxes2).mean())
    # return 0.5 * (np.array(r_Cs) + np.array(r_DDOs))
    return swf(np.array(r_Cs), np.array(r_DDOs))
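# A sketch of driving all three games under a common welfare function
# (`utilitarian_swf` is the hypothetical helper defined after
# `run_data_coop_game` above):
#
#   for game in (run_data_coop_game,
#                run_data_coop_game_with_regulator,
#                run_data_coop_game_with_gaussian_regulator):
#       print(game.__name__, game(seed=0, swf=utilitarian_swf))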